diff --git a/docs/source/en/model_doc/sam.mdx b/docs/source/en/model_doc/sam.mdx index 30228ce5865..70e93d2ae2c 100644 --- a/docs/source/en/model_doc/sam.mdx +++ b/docs/source/en/model_doc/sam.mdx @@ -64,6 +64,7 @@ scores = outputs.iou_scores Resources: - [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model +- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/automatic_mask_generation.ipynb) for using automatic mask generation pipeline. ## SamConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4a67cbed30e..6171da71905 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1012,6 +1012,7 @@ else: "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING", + "MODEL_FOR_MASK_GENERATION_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "MODEL_FOR_OBJECT_DETECTION_MAPPING", @@ -4650,6 +4651,7 @@ if TYPE_CHECKING: MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, + MODEL_FOR_MASK_GENERATION_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 4eccfded5b6..944b0ccb382 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -52,6 +52,7 @@ else: "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING", + "MODEL_FOR_MASK_GENERATION_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "MODEL_FOR_OBJECT_DETECTION_MAPPING", @@ -213,6 +214,7 @@ if TYPE_CHECKING: MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, + MODEL_FOR_MASK_GENERATION_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 393f192bb5d..33361568967 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -977,7 +977,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ] ) -MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES = OrderedDict( +MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( [ ("sam", "SamModel"), ] @@ -1058,9 +1058,11 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_F MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES) -MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING = _LazyAutoMapping( - CONFIG_MAPPING_NAMES, MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES -) +MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) + + +class AutoModelForMaskGeneration(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING class AutoModel(_BaseAutoModelClass): diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 4c2aacac86a..361567f704e 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -13,7 +13,10 @@ # See the 
License for the specific language governing permissions and # limitations under the License. """Image processor class for SAM.""" -from typing import Dict, List, Optional, Tuple, Union +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -26,16 +29,20 @@ from ...image_utils import ( ImageInput, PILImageResampling, get_image_size, + infer_channel_dimension_format, make_list_of_images, to_numpy_array, valid_images, ) -from ...utils import TensorType, is_torch_available, logging, requires_backends +from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends if is_torch_available(): + import torch import torch.nn.functional as F +if is_torchvision_available(): + from torchvision.ops.boxes import batched_nms logger = logging.get_logger(__name__) @@ -354,10 +361,14 @@ class SamImageProcessor(BaseImageProcessor): images = [self.pad_image(image=image, pad_size=pad_size) for image in images] images = [to_channel_dimension_format(image, data_format) for image in images] - - data = {"pixel_values": images, "original_sizes": original_sizes, "reshaped_input_sizes": reshaped_input_sizes} - encoded_outputs = BatchFeature(data=data, tensor_type=return_tensors) - + encoded_outputs = BatchFeature( + data={ + "pixel_values": images, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + }, + tensor_type=return_tensors, + ) return encoded_outputs def post_process_masks( @@ -392,11 +403,453 @@ class SamImageProcessor(BaseImageProcessor): for i, original_size in enumerate(original_sizes): interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False) interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] - interpolated_mask = F.interpolate( - interpolated_mask, [*original_size.numpy()], mode="bilinear", align_corners=False - ) + interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False) if binarize: interpolated_mask = interpolated_mask > mask_threshold output_masks.append(interpolated_mask) return output_masks + + def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. + + Args: + all_masks (`List[torch.Tensor]`): + List of all predicted segmentation masks + all_scores (`List[torch.Tensor]`): + List of all predicted iou scores + all_boxes (`List[torch.Tensor]`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + """ + return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) + + def generate_crop_boxes( + self, + image, + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[List[int]] = 1, + device: Optional["torch.device"] = None, + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`np.array`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. 
Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sample from each crop. + crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*, defaults to None): + Device to use for the computation. If None, cpu will be used. + """ + return _generate_crop_boxes( + image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device + ) + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`torch.Tensor`): + Input masks. + iou_scores (`torch.Tensor`): + List of IoU scores. + original_size (`Tuple[int,int]`): + Size of the orginal image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. 
+ + """ + requires_backends(self, ["torch"]) + original_height, original_width = original_size + iou_scores = iou_scores.flatten(0, 1) + masks = masks.flatten(0, 1) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + if masks.device != iou_scores.device: + iou_scores = iou_scores.to(masks.device) + + batch_size = masks.shape[0] + + keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box(masks) + + keep_mask = ~_is_box_near_crop_edge( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppresion + masks = _mask_to_rle_pytorch(masks) + + return masks, scores, converted_boxes + + +def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): + # One mask is always contained inside the other. + # Save memory by preventing unnecesary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + ) + unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + stability_scores = intersections / unions + return stability_scores + + +def _build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def _normalize_coordinates( + target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False +) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. + """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _generate_crop_boxes( + image, + target_size: int, # Is it tuple here? 
+ crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[List[int]] = 1, + device: Optional["torch.device"] = None, +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. + overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*): + Number of points to sample per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*): + Device to run the crop generation on. Defaults to CPU. + """ + if device is None: + device = torch.device("cpu") + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + image = to_numpy_array(image) + original_size = get_image_size(image) + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size + ) + + crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device) + point_grid_per_crop = np.array([point_grid_per_crop]) + points_per_crop = torch.tensor(point_grid_per_crop, device=device) + points_per_crop = points_per_crop.permute(0, 2, 1, 3) + + input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. 
Crops are in the XYWH format : The XYWH format + consists of the following required indices: + - X: X coordinate of the top left of the bounding box + - Y: Y coordinate of the top left of the bounding box + - W: width of the bounding box + - H: height of the bounding box + """ + crop_boxes, layer_idxs = [], [] + im_height, im_width = original_size + short_side = min(im_height, im_width) + + # Original image + crop_boxes.append([0, 0, im_width, im_height]) + layer_idxs.append(0) + for i_layer in range(crop_n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side)) + crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side)) + + crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)] + + for left, top in product(crop_box_x0, crop_box_y0): + box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_size, original_size): + """ + Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are + also passed. + """ + cropped_images = [] + total_points_per_crop = [] + for i, crop_box in enumerate(crop_boxes): + left, top, right, bottom = crop_box + + channel_dim = infer_channel_dimension_format(image) + if channel_dim == ChannelDimension.LAST: + cropped_im = image[top:bottom, left:right, :] + else: + cropped_im = image[:, top:bottom, left:right] + + cropped_images.append(cropped_im) + + cropped_im_size = get_image_size(cropped_im) + points_scale = np.array(cropped_im_size)[None, ::-1] + + points = points_grid[layer_idxs[i]] * points_scale + normalized_points = _normalize_coordinates(target_size, points, original_size) + total_points_per_crop.append(normalized_points) + + return cropped_images, total_points_per_crop + + +def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return torch.nn.functional.pad(masks, pad, value=0) + + +def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + + left, top, _, _ = crop_box + offset = torch.tensor([[left, top, left, top]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + boxes = (boxes + offset).float() + + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def _batched_mask_to_box(masks: "torch.Tensor"): + """ + Computes the bounding 
boxes around the given input masks. The bounding boxes are in the XYXY format which + corresponds the following required indices: + - LEFT: left hand side of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + + Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape + is channel_1 x channel_2 x ... x 4. + + Args: + - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`) + """ + # torch.max below raises an error on empty inputs, just skip in this case + + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to Cxheightxwidth + shape = masks.shape + height, width = shape[-2:] + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + height * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + width * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + out = out.reshape(*shape[:-2], 4) + return out + + +def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): + """ + Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. + """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = input_mask.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]] + out.append({"size": [height, width], "counts": counts}) + return out + + +def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + height, width = rle["size"] + mask = np.empty(height * width, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity = not parity + mask = mask.reshape(width, height) + return mask.transpose() # Reshape to original shape + + +def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maxium Suppression) on the outputs. + + Args: + rle_masks (`torch.Tensor`): + binary masks in the RLE format + iou_scores (`torch.Tensor` of shape (nb_masks, 1)): + iou_scores predicted by the model + mask_boxes (`torch.Tensor`): + The bounding boxes corresponding to segmentation masks + amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): + NMS threshold. 
+ """ + keep_by_nms = batched_nms( + boxes=mask_boxes.float(), + scores=iou_scores, + idxs=torch.zeros(mask_boxes.shape[0]), + iou_threshold=amg_crops_nms_thresh, + ) + + iou_scores = iou_scores[keep_by_nms] + rle_masks = [rle_masks[i] for i in keep_by_nms] + mask_boxes = mask_boxes[keep_by_nms] + masks = [_rle_to_mask(rle) for rle in rle_masks] + + return masks, iou_scores, rle_masks, mask_boxes diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index b4e69661388..84d461cd1ae 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline from .image_classification import ImageClassificationPipeline from .image_segmentation import ImageSegmentationPipeline from .image_to_text import ImageToTextPipeline +from .mask_generation import MaskGenerationPipeline from .object_detection import ObjectDetectionPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline @@ -124,6 +125,7 @@ if is_torch_available(): AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForMaskedLM, + AutoModelForMaskGeneration, AutoModelForObjectDetection, AutoModelForQuestionAnswering, AutoModelForSemanticSegmentation, @@ -384,6 +386,13 @@ SUPPORTED_TASKS = { "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}}, "type": "video", }, + "mask-generation": { + "impl": MaskGenerationPipeline, + "tf": (), + "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (), + "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}}, + "type": "multimodal", + }, } NO_FEATURE_EXTRACTOR_TASKS = set() @@ -536,6 +545,7 @@ def pipeline( - `"image-classification"`: will return a [`ImageClassificationPipeline`]. - `"image-segmentation"`: will return a [`ImageSegmentationPipeline`]. - `"image-to-text"`: will return a [`ImageToTextPipeline`]. + - `"mask-generation"`: will return a [`MaskGenerationPipeline`]. - `"object-detection"`: will return a [`ObjectDetectionPipeline`]. - `"question-answering"`: will return a [`QuestionAnsweringPipeline`]. - `"summarization"`: will return a [`SummarizationPipeline`]. 
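The hunk above registers the `"mask-generation"` task, backed by `MaskGenerationPipeline` with `facebook/sam-vit-huge` as its default checkpoint. A minimal usage sketch of that wiring; the keyword values shown are simply the defaults this diff declares for the pipeline, spelled out here for illustration:

```python
from transformers import pipeline

# "mask-generation" resolves to MaskGenerationPipeline; facebook/sam-vit-huge
# is the default checkpoint registered in SUPPORTED_TASKS above.
generator = pipeline("mask-generation", model="facebook/sam-vit-huge")

outputs = generator(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    points_per_batch=64,          # prompt points sent through the mask decoder per forward pass
    pred_iou_thresh=0.88,         # drop masks whose predicted IoU quality is below this value
    stability_score_thresh=0.95,  # drop masks that are unstable under threshold perturbation
    crops_nms_thresh=0.7,         # IoU cutoff for the final non-maximum suppression
)

print(len(outputs["masks"]), outputs["scores"][:3])
```

Passing `output_rle_mask=True` or `output_bboxes_mask=True` adds the corresponding `rle_mask` and `bounding_boxes` entries to the returned dictionary, as implemented in `postprocess` further down.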
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index cdb597ef966..b728e94f34e 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side): tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value elif dim == 3: tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value + elif dim == 4: + tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value for i, item in enumerate(items): if dim == 2: @@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side): tensor[i, -len(item[key][0]) :, :] = item[key][0].clone() else: tensor[i, : len(item[key][0]), :] = item[key][0].clone() + elif dim == 4: + if padding_side == "left": + tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone() + else: + tensor[i, : len(item[key][0]), :, :] = item[key][0].clone() + return tensor else: return [item[key] for item in items] diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index db349944a4f..00f52af2872 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline): ) def _sanitize_parameters(self, **kwargs): - preprocessor_kwargs = {} + preprocess_kwargs = {} postprocess_kwargs = {} if "subtask" in kwargs: postprocess_kwargs["subtask"] = kwargs["subtask"] - preprocessor_kwargs["subtask"] = kwargs["subtask"] + preprocess_kwargs["subtask"] = kwargs["subtask"] if "threshold" in kwargs: postprocess_kwargs["threshold"] = kwargs["threshold"] if "mask_threshold" in kwargs: @@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline): if "overlap_mask_area_threshold" in kwargs: postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"] - return preprocessor_kwargs, {}, postprocess_kwargs + return preprocess_kwargs, {}, postprocess_kwargs def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]: """ diff --git a/src/transformers/pipelines/mask_generation.py b/src/transformers/pipelines/mask_generation.py new file mode 100644 index 00000000000..650d4a41416 --- /dev/null +++ b/src/transformers/pipelines/mask_generation.py @@ -0,0 +1,286 @@ +from collections import defaultdict +from typing import Optional + +from ..image_utils import load_image +from ..utils import ( + add_end_docstrings, + is_torch_available, + logging, + requires_backends, +) +from .base import PIPELINE_INIT_ARGS, ChunkPipeline + + +if is_torch_available(): + import torch + + from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING + +logger = logging.get_logger(__name__) + + +@add_end_docstrings(PIPELINE_INIT_ARGS) +class MaskGenerationPipeline(ChunkPipeline): + """ + Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an + image, given an image. It is a `ChunkPipeline` because you can seperate the points in a mini-batch in order to + avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the + same time. Default is `64`. + + The pipeline works in 3 steps: + 1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point + labels. 
+ For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes` + function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of + `points_per_batch`. + + 2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once. + Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the + tensors and models are on the same device. + + 3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps + are induced: + - image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks, + resizes them according + to the image size, and transforms there to binary masks. + - image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and + `stability_scores`. Also + applies a variety of filters based on non maximum suppression to remove bad masks. + - image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones. + + Arguments: + model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from + [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from + [`PreTrainedTokenizer`]. + feature_extractor ([`SequenceFeatureExtractor`]): + The feature extractor that will be used by the pipeline to encode the input. + points_per_batch (*optional*, int, default to 64): + Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU + memory. + output_bboxes_mask (`bool`, *optional*, default to `False`): + Whether or not to output the bounding box predictions. + output_rle_masks (`bool`, *optional*, default to `False`): + Whether or not to output the masks in `RLE` format + + Example: + + ```python + >>> from transformers import pipeline + + >>> generator = pipeline(model="facebook/sam-vit-h", task="mask-generation") + >>> outputs = generator( + ... "http://images.cocodataset.org/val2017/000000039769.jpg", + ... ) + + >>> outputs = generator( + ... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128 + ... ) + ``` + + Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) + + This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier: + `"mask-generation"`. + + See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation). 
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + requires_backends(self, "vision") + requires_backends(self, "torch") + + if self.framework != "pt": + raise ValueError(f"The {self.__class__} is only available in PyTorch.") + + self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING) + + def _sanitize_parameters(self, **kwargs): + preprocess_kwargs = {} + postprocess_kwargs = {} + forward_params = {} + # preprocess args + if "points_per_batch" in kwargs: + preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"] + if "points_per_crop" in kwargs: + preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"] + if "crops_n_layers" in kwargs: + preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"] + if "crop_overlap_ratio" in kwargs: + preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"] + if "crop_n_points_downscale_factor" in kwargs: + preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"] + # postprocess args + if "pred_iou_thresh" in kwargs: + forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"] + if "stability_score_offset" in kwargs: + forward_params["stability_score_offset"] = kwargs["stability_score_offset"] + if "mask_threshold" in kwargs: + forward_params["mask_threshold"] = kwargs["mask_threshold"] + if "stability_score_thresh" in kwargs: + forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"] + if "crops_nms_thresh" in kwargs: + postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"] + if "output_rle_mask" in kwargs: + postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"] + if "output_bboxes_mask" in kwargs: + postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"] + return preprocess_kwargs, forward_params, postprocess_kwargs + + def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs): + """ + Generates binary segmentation masks + + Args: + inputs (`np.ndarray` or `bytes` or `str` or `dict`): + Image or list of images. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold to use when turning the predicted masks into binary values. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + A filtering threshold in `[0,1]` applied on the model's predicted mask quality. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to + binarize the model's mask predictions. + stability_score_offset (`int`, *optional*, defaults to 1): + The amount to shift the cutoff when calculated the stability score. + crops_nms_thresh (`float`, *optional*, defaults to 0.7): + The box IoU cutoff used by non-maximal suppression to filter duplicate masks. + crops_n_layers (`int`, *optional*, defaults to 0): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of + layers to run, where each layer has 2**i_layer number of image crops. + crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 
+ + Return: + `Dict`: A dictionary with the following keys: + - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width, + height)` of the original image. Returns a mask filled with zeros if no object is found. + - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of + the "object" described by the label and the mask. + + """ + return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs) + + def preprocess( + self, + image, + points_per_batch=64, + crops_n_layers: int = 0, + crop_overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[int] = 1, + ): + image = load_image(image) + target_size = self.image_processor.size["longest_edge"] + crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes( + image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor + ) + model_inputs = self.image_processor(images=cropped_images, return_tensors="pt") + + with self.device_placement(): + if self.framework == "pt": + inference_context = self.get_inference_context() + with inference_context(): + model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) + image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values")) + model_inputs["image_embeddings"] = image_embeddings + + n_points = grid_points.shape[1] + points_per_batch = points_per_batch if points_per_batch is not None else n_points + + if points_per_batch <= 0: + raise ValueError( + "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. " + "To return all points at once, set points_per_batch to None" + ) + + for i in range(0, n_points, points_per_batch): + batched_points = grid_points[:, i : i + points_per_batch, :, :] + labels = input_labels[:, i : i + points_per_batch] + is_last = i == n_points - points_per_batch + yield { + "input_points": batched_points, + "input_labels": labels, + "input_boxes": crop_boxes, + "is_last": is_last, + **model_inputs, + } + + def _forward( + self, + model_inputs, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + input_boxes = model_inputs.pop("input_boxes") + is_last = model_inputs.pop("is_last") + original_sizes = model_inputs.pop("original_sizes").tolist() + reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist() + + model_outputs = self.model(**model_inputs) + + # post processing happens here in order to avoid CPU GPU copies of ALL the masks + low_resolution_masks = model_outputs["pred_masks"] + masks = self.image_processor.post_process_masks( + low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False + ) + iou_scores = model_outputs["iou_scores"] + masks, iou_scores, boxes = self.image_processor.filter_masks( + masks[0], + iou_scores[0], + original_sizes[0], + input_boxes[0], + pred_iou_thresh, + stability_score_thresh, + mask_threshold, + stability_score_offset, + ) + return { + "masks": masks, + "is_last": is_last, + "boxes": boxes, + "iou_scores": iou_scores, + } + + def postprocess( + self, + model_outputs, + output_rle_mask=False, + output_bboxes_mask=False, + crops_nms_thresh=0.7, + ): + all_scores = [] + all_masks = [] + all_boxes = [] + for model_output in model_outputs: + all_scores.append(model_output.pop("iou_scores")) + all_masks.extend(model_output.pop("masks")) + 
all_boxes.append(model_output.pop("boxes")) + + all_scores = torch.cat(all_scores) + all_boxes = torch.cat(all_boxes) + output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation( + all_masks, all_scores, all_boxes, crops_nms_thresh + ) + + extra = defaultdict(list) + for output in model_outputs: + for k, v in output.items(): + extra[k].append(v) + + optional = {} + if output_rle_mask: + optional["rle_mask"] = rle_mask + + if output_bboxes_mask: + optional["bounding_boxes"] = bounding_boxes + + return {"masks": output_masks, "scores": iou_scores, **optional, **extra} diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4e3133b9b12..9b607e918e5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -475,6 +475,9 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None +MODEL_FOR_MASK_GENERATION_MAPPING = None + + MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index d78c5416379..a8e5461be8a 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase): self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4)) def test_inference_mask_generation_one_point_one_bb(self): - model = SamModel.from_pretrained("facebook/sam-vit-h") - processor = SamProcessor.from_pretrained("facebook/sam-vit-h") + model = SamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") model.to(torch_device) model.eval() diff --git a/tests/pipelines/test_pipelines_mask_generation.py b/tests/pipelines/test_pipelines_mask_generation.py new file mode 100644 index 00000000000..f6ae6a2849f --- /dev/null +++ b/tests/pipelines/test_pipelines_mask_generation.py @@ -0,0 +1,142 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
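The test file below compares masks by hash rather than pixel by pixel. As a note on the `postprocess` step above: with `output_rle_mask=True` the pipeline returns masks in the uncompressed RLE format produced by `_mask_to_rle_pytorch` and decoded by `_rle_to_mask`. A small NumPy round-trip illustrating that convention (the toy mask and the helper-free encoding are illustrative; only the format itself comes from this diff):

```python
import numpy as np

# Masks are flattened in column-major (Fortran) order and stored as alternating
# run lengths, starting with a run of background pixels (hence the leading 0
# added by _mask_to_rle_pytorch when the very first pixel is foreground).
mask = np.array([[0, 1, 1],
                 [0, 1, 0]], dtype=bool)

flat = mask.T.reshape(-1)  # column-major flattening -> [0, 0, 1, 1, 1, 0]
changes = np.flatnonzero(flat[1:] != flat[:-1]) + 1
counts = np.diff(np.concatenate(([0], changes, [flat.size]))).tolist()
if flat[0]:
    counts = [0] + counts  # by convention the encoding starts with a zero-run
rle = {"size": list(mask.shape), "counts": counts}  # {"size": [2, 3], "counts": [2, 3, 1]}

# Decoding follows the same steps as _rle_to_mask
height, width = rle["size"]
decoded = np.zeros(height * width, dtype=bool)
idx, parity = 0, False
for count in rle["counts"]:
    decoded[idx : idx + count] = parity
    idx += count
    parity = not parity
decoded = decoded.reshape(width, height).T

assert np.array_equal(decoded, mask)
```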
+ +import hashlib +import unittest +from typing import Dict + +import numpy as np + +from transformers import MODEL_FOR_MASK_GENERATION_MAPPING, is_vision_available, pipeline +from transformers.pipelines import MaskGenerationPipeline +from transformers.testing_utils import ( + is_pipeline_test, + nested_simplify, + require_tf, + require_torch, + require_vision, + slow, +) + + +if is_vision_available(): + from PIL import Image + + +def hashimage(image: Image) -> str: + m = hashlib.md5(image.tobytes()) + return m.hexdigest()[:10] + + +def mask_to_test_readable(mask: Image) -> Dict: + npimg = np.array(mask) + shape = npimg.shape + return {"hash": hashimage(mask), "shape": shape} + + +@is_pipeline_test +@require_vision +@require_torch +class MaskGenerationPipelineTests(unittest.TestCase): + model_mapping = dict( + (list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else []) + ) + + def get_test_pipeline(self, model, tokenizer, processor): + image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor) + return image_segmenter, [ + "./tests/fixtures/tests_samples/COCO/000000039769.png", + "./tests/fixtures/tests_samples/COCO/000000039769.png", + ] + + @require_tf + @unittest.skip("Image segmentation not implemented in TF") + def test_small_model_tf(self): + pass + + @slow + @require_torch + def test_small_model_pt(self): + image_segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge") + + outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", points_per_batch=256) + + # Shortening by hashing + new_outupt = [] + for i, o in enumerate(outputs["masks"]): + new_outupt += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}] + + # fmt: off + self.assertEqual( + nested_simplify(new_outupt, decimals=4), + [ + {'mask': {'hash': '115ad19f5f', 'shape': (480, 640)}, 'scores': 1.0444}, + {'mask': {'hash': '6affa964c6', 'shape': (480, 640)}, 'scores': 1.021}, + {'mask': {'hash': 'dfe28a0388', 'shape': (480, 640)}, 'scores': 1.0167}, + {'mask': {'hash': 'c0a5f4a318', 'shape': (480, 640)}, 'scores': 1.0132}, + {'mask': {'hash': 'fe8065c197', 'shape': (480, 640)}, 'scores': 1.0053}, + {'mask': {'hash': 'e2d0b7a0b7', 'shape': (480, 640)}, 'scores': 0.9967}, + {'mask': {'hash': '453c7844bd', 'shape': (480, 640)}, 'scores': 0.993}, + {'mask': {'hash': '3d44f2926d', 'shape': (480, 640)}, 'scores': 0.9909}, + {'mask': {'hash': '64033ddc3f', 'shape': (480, 640)}, 'scores': 0.9879}, + {'mask': {'hash': '801064ff79', 'shape': (480, 640)}, 'scores': 0.9834}, + {'mask': {'hash': '6172f276ef', 'shape': (480, 640)}, 'scores': 0.9716}, + {'mask': {'hash': 'b49e60e084', 'shape': (480, 640)}, 'scores': 0.9612}, + {'mask': {'hash': 'a811e775fd', 'shape': (480, 640)}, 'scores': 0.9599}, + {'mask': {'hash': 'a6a8ebcf4b', 'shape': (480, 640)}, 'scores': 0.9552}, + {'mask': {'hash': '9d8257e080', 'shape': (480, 640)}, 'scores': 0.9532}, + {'mask': {'hash': '32de6454a8', 'shape': (480, 640)}, 'scores': 0.9516}, + {'mask': {'hash': 'af3d4af2c8', 'shape': (480, 640)}, 'scores': 0.9499}, + {'mask': {'hash': '3c6db475fb', 'shape': (480, 640)}, 'scores': 0.9483}, + {'mask': {'hash': 'c290813fb9', 'shape': (480, 640)}, 'scores': 0.9464}, + {'mask': {'hash': 'b6f0b8f606', 'shape': (480, 640)}, 'scores': 0.943}, + {'mask': {'hash': '92ce16bfdf', 'shape': (480, 640)}, 'scores': 0.943}, + {'mask': {'hash': 'c749b25868', 'shape': (480, 640)}, 'scores': 0.9408}, + {'mask': {'hash': 'efb6cab859', 'shape': (480, 640)}, 'scores': 
0.9335}, + {'mask': {'hash': '1ff2eafb30', 'shape': (480, 640)}, 'scores': 0.9326}, + {'mask': {'hash': '788b798e24', 'shape': (480, 640)}, 'scores': 0.9262}, + {'mask': {'hash': 'abea804f0e', 'shape': (480, 640)}, 'scores': 0.8999}, + {'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986}, + {'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984}, + {'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873}, + {'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871} + ], + ) + # fmt: on + + @require_torch + @slow + def test_threshold(self): + model_id = "facebook/sam-vit-huge" + image_segmenter = pipeline("mask-generation", model=model_id) + + outputs = image_segmenter( + "http://images.cocodataset.org/val2017/000000039769.jpg", pred_iou_thresh=1, points_per_batch=256 + ) + + # Shortening by hashing + new_output = [] + for i, o in enumerate(outputs["masks"]): + new_output += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}] + + self.assertEqual( + nested_simplify(new_output, decimals=4), + [ + {"mask": {"hash": "115ad19f5f", "shape": (480, 640)}, "scores": 1.0444}, + {"mask": {"hash": "6affa964c6", "shape": (480, 640)}, "scores": 1.0210}, + {"mask": {"hash": "dfe28a0388", "shape": (480, 640)}, "scores": 1.0167}, + {"mask": {"hash": "c0a5f4a318", "shape": (480, 640)}, "scores": 1.0132}, + {"mask": {"hash": "fe8065c197", "shape": (480, 640)}, "scores": 1.0053}, + ], + ) diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 8c34bba5d6a..c34d8a39237 100644 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -98,6 +98,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ ), ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"), ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"), + ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"), ]
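For completeness, the flow that `MaskGenerationPipeline` automates can also be driven by hand with the image processor helpers added in this diff. The sketch below mirrors `preprocess` and `_forward`, loading `SamModel` directly (the model `AutoModelForMaskGeneration` resolves to for SAM checkpoints); `AutoImageProcessor`, `requests` and `PIL` usage are assumptions on my part, everything else comes from the diff. Treat it as an illustration rather than an official recipe:

```python
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, SamModel

# AutoModelForMaskGeneration maps "sam" to SamModel through the new
# MODEL_FOR_MASK_GENERATION_MAPPING; SamModel is loaded directly here.
processor = AutoImageProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge").eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# Crop boxes, point grid and point labels, as MaskGenerationPipeline.preprocess does
target_size = processor.size["longest_edge"]
crop_boxes, grid_points, cropped_images, input_labels = processor.generate_crop_boxes(image, target_size)
inputs = processor(images=cropped_images, return_tensors="pt")

with torch.no_grad():
    # The image embedding is computed once and reused for every chunk of points
    image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
    outputs = model(
        input_points=grid_points[:, :64],   # first chunk of 64 prompt points
        input_labels=input_labels[:, :64],
        image_embeddings=image_embeddings,
    )

masks = processor.post_process_masks(
    outputs.pred_masks,
    inputs["original_sizes"].tolist(),
    inputs["reshaped_input_sizes"].tolist(),
    binarize=False,
)
print(masks[0].shape, outputs.iou_scores.shape)
```

Chunking the full point grid by `points_per_batch`, filtering each chunk with `filter_masks`, and running `post_process_for_mask_generation` over the aggregated results is exactly what the pipeline's `_forward` and `postprocess` steps add on top of this.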