Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Add automatic-mask-generation pipeline for Segment Anything Model (SAM) (#22840)
* cleanup
* updates
* more refactoring
* make style
* update inits
* support other inputs in base
* update based on review (Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>)
* Update tests/pipelines/test_pipelines_automatic_mask_generation.py (Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>)
* update
* fixup
* TODO x and y to refactor, _h _w refactored here
* update docstring
* more nits
* style on these
* more doc fix
* rename variables
* update
* updates
* style
* update
* fix `_mask_to_rle_pytorch`
* styling
* fix mask to rle, wrong outputs
* add device arg
* update
* more updates, fix tests
* update
* update docstrings
* styling
* fixup
* add notebook on the docs
* update original sizes
* fix docstring
* update condition on points_per_batch
* update tests
* fix CI test
* extend is required, append does not work!
* fixup
* fix CI tests
* whit pixels left
* address doc comments
* fix doc
* slow pipeline tests
* update auto init
* add revision
* make fixup
* update pipeline tag when calling tests
* alphabetical order in inits
* fix copies
* last style nits
* Apply suggestions from code review (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* reformat docstring
* more reformat
* address most of the comments
* Update src/transformers/pipelines/mask_generation.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* final refactor
* Update src/transformers/models/sam/image_processing_sam.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>)
* fixup and fix slow tests
* revert

---------

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in: parent e5f3487190, commit f143037789
@@ -64,6 +64,7 @@ scores = outputs.iou_scores

 Resources:

 - [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model
+- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/automatic_mask_generation.ipynb) for using the automatic mask generation pipeline.

 ## SamConfig
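For orientation, a minimal usage sketch of the new pipeline (the task tag, checkpoint name and output keys are the ones registered later in this commit):

```python
from transformers import pipeline

# "mask-generation" is the task identifier added by this commit
generator = pipeline("mask-generation", model="facebook/sam-vit-huge")

outputs = generator("http://images.cocodataset.org/val2017/000000039769.jpg", points_per_batch=256)
masks, scores = outputs["masks"], outputs["scores"]  # one binary mask and one IoU score per detected object
```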
@@ -1012,6 +1012,7 @@ else:
         "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
         "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
         "MODEL_FOR_MASKED_LM_MAPPING",
+        "MODEL_FOR_MASK_GENERATION_MAPPING",
         "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
         "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
         "MODEL_FOR_OBJECT_DETECTION_MAPPING",
@@ -4650,6 +4651,7 @@ if TYPE_CHECKING:
         MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
         MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
         MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
+        MODEL_FOR_MASK_GENERATION_MAPPING,
         MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
@@ -52,6 +52,7 @@ else:
         "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
         "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
         "MODEL_FOR_MASKED_LM_MAPPING",
+        "MODEL_FOR_MASK_GENERATION_MAPPING",
         "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
         "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
         "MODEL_FOR_OBJECT_DETECTION_MAPPING",
@@ -213,6 +214,7 @@ if TYPE_CHECKING:
         MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
         MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
         MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
+        MODEL_FOR_MASK_GENERATION_MAPPING,
         MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
@@ -977,7 +977,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
     ]
 )

-MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
+MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
     [
         ("sam", "SamModel"),
     ]
@@ -1058,9 +1058,11 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_F

 MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)

-MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING = _LazyAutoMapping(
-    CONFIG_MAPPING_NAMES, MODEL_FOR_AUTOMATIC_MASK_GENERATION_MAPPING_NAMES
-)
+MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
+
+
+class AutoModelForMaskGeneration(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING


 class AutoModel(_BaseAutoModelClass):
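A quick sketch of what the new auto class enables (import path as defined in this file; `facebook/sam-vit-huge` is the checkpoint used in the tests below):

```python
from transformers.models.auto.modeling_auto import AutoModelForMaskGeneration

# "sam" is the only entry in MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, so this resolves to SamModel
model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-huge")
```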
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Image processor class for SAM."""
-from typing import Dict, List, Optional, Tuple, Union
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -26,16 +29,20 @@ from ...image_utils import (
     ImageInput,
     PILImageResampling,
     get_image_size,
+    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
 )
-from ...utils import TensorType, is_torch_available, logging, requires_backends
+from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends


 if is_torch_available():
+    import torch
     import torch.nn.functional as F

+if is_torchvision_available():
+    from torchvision.ops.boxes import batched_nms
+

 logger = logging.get_logger(__name__)
@@ -354,10 +361,14 @@ class SamImageProcessor(BaseImageProcessor):
             images = [self.pad_image(image=image, pad_size=pad_size) for image in images]

         images = [to_channel_dimension_format(image, data_format) for image in images]
-        data = {"pixel_values": images, "original_sizes": original_sizes, "reshaped_input_sizes": reshaped_input_sizes}
-        encoded_outputs = BatchFeature(data=data, tensor_type=return_tensors)
+        encoded_outputs = BatchFeature(
+            data={
+                "pixel_values": images,
+                "original_sizes": original_sizes,
+                "reshaped_input_sizes": reshaped_input_sizes,
+            },
+            tensor_type=return_tensors,
+        )
         return encoded_outputs

     def post_process_masks(
@@ -392,11 +403,453 @@ class SamImageProcessor(BaseImageProcessor):
         for i, original_size in enumerate(original_sizes):
             interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
             interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
-            interpolated_mask = F.interpolate(
-                interpolated_mask, [*original_size.numpy()], mode="bilinear", align_corners=False
-            )
+            interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
             if binarize:
                 interpolated_mask = interpolated_mask > mask_threshold
             output_masks.append(interpolated_mask)

         return output_masks

The remainder of this hunk is newly added code:
    def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
        """
        Post-processes the masks generated by the model by calling Non Maximum Suppression on the predicted masks.

        Args:
            all_masks (`List[torch.Tensor]`):
                List of all predicted segmentation masks
            all_scores (`List[torch.Tensor]`):
                List of all predicted iou scores
            all_boxes (`List[torch.Tensor]`):
                List of all bounding boxes of the predicted masks
            crops_nms_thresh (`float`):
                Threshold for NMS (Non Maximum Suppression) algorithm.
        """
        return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)

    def generate_crop_boxes(
        self,
        image,
        target_size,
        crop_n_layers: int = 0,
        overlap_ratio: float = 512 / 1500,
        points_per_crop: Optional[int] = 32,
        crop_n_points_downscale_factor: Optional[List[int]] = 1,
        device: Optional["torch.device"] = None,
    ):
        """
        Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.

        Args:
            image (`np.array`):
                Input original image
            target_size (`int`):
                Target size of the resized image
            crop_n_layers (`int`, *optional*, defaults to 0):
                If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run,
                where each layer has 2**i_layer number of image crops.
            overlap_ratio (`float`, *optional*, defaults to 512/1500):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction
                of the image length. Later layers with more crops scale down this overlap.
            points_per_crop (`int`, *optional*, defaults to 32):
                Number of points to sample from each crop.
            crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
            device (`torch.device`, *optional*, defaults to None):
                Device to use for the computation. If None, cpu will be used.
        """
        return _generate_crop_boxes(
            image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
        )

    def filter_masks(
        self,
        masks,
        iou_scores,
        original_size,
        cropped_box_image,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        """
        Filters the predicted masks by selecting only the ones that meet several criteria. The first criterion is
        that the iou scores need to be greater than `pred_iou_thresh`. The second criterion is that the stability
        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
        bounding boxes and pads the predicted masks if necessary.

        Args:
            masks (`torch.Tensor`):
                Input masks.
            iou_scores (`torch.Tensor`):
                List of IoU scores.
            original_size (`Tuple[int,int]`):
                Size of the original image.
            cropped_box_image (`np.array`):
                The cropped image.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                The threshold for the iou scores.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                The threshold for the stability score.
            mask_threshold (`float`, *optional*, defaults to 0):
                The threshold for the predicted masks.
            stability_score_offset (`float`, *optional*, defaults to 1):
                The offset for the stability score used in the `_compute_stability_score` method.
        """
        requires_backends(self, ["torch"])
        original_height, original_width = original_size
        iou_scores = iou_scores.flatten(0, 1)
        masks = masks.flatten(0, 1)

        if masks.shape[0] != iou_scores.shape[0]:
            raise ValueError("masks and iou_scores must have the same batch size.")

        if masks.device != iou_scores.device:
            iou_scores = iou_scores.to(masks.device)

        batch_size = masks.shape[0]

        keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)

        if pred_iou_thresh > 0.0:
            keep_mask = keep_mask & (iou_scores > pred_iou_thresh)

        # compute stability score
        if stability_score_thresh > 0.0:
            stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
            keep_mask = keep_mask & (stability_scores > stability_score_thresh)

        scores = iou_scores[keep_mask]
        masks = masks[keep_mask]

        # binarize masks
        masks = masks > mask_threshold
        converted_boxes = _batched_mask_to_box(masks)

        keep_mask = ~_is_box_near_crop_edge(
            converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
        )

        scores = scores[keep_mask]
        masks = masks[keep_mask]
        converted_boxes = converted_boxes[keep_mask]

        masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
        # conversion to rle is necessary to run non-maximum suppression
        masks = _mask_to_rle_pytorch(masks)

        return masks, scores, converted_boxes


def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
    intersections = (
        (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    )
    unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    stability_scores = intersections / unions
    return stability_scores
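# Worked example (hypothetical numbers): if 900 pixels of a mask exceed mask_threshold + offset and
# 1000 pixels exceed mask_threshold - offset, the stability score is 900 / 1000 = 0.9, i.e. the mask
# changes relatively little when the binarization cutoff is shifted; with the default
# stability_score_thresh of 0.95 such a mask would still be filtered out.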
def _build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points
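# Illustrative example: _build_point_grid(2) returns the 4 normalized (x, y) points
# [[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]], i.e. a 2x2 grid centered in each cell.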
def _normalize_coordinates(
    target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
) -> np.ndarray:
    """
    Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
    format.
    """
    old_height, old_width = original_size

    scale = target_size * 1.0 / max(old_height, old_width)
    new_height, new_width = old_height * scale, old_width * scale
    new_width = int(new_width + 0.5)
    new_height = int(new_height + 0.5)

    coords = deepcopy(coords).astype(float)

    if is_bounding_box:
        coords = coords.reshape(-1, 2, 2)

    coords[..., 0] = coords[..., 0] * (new_width / old_width)
    coords[..., 1] = coords[..., 1] * (new_height / old_height)

    if is_bounding_box:
        coords = coords.reshape(-1, 4)

    return coords


def _generate_crop_boxes(
    image,
    target_size: int,  # Is it tuple here?
    crop_n_layers: int = 0,
    overlap_ratio: float = 512 / 1500,
    points_per_crop: Optional[int] = 32,
    crop_n_points_downscale_factor: Optional[List[int]] = 1,
    device: Optional["torch.device"] = None,
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.

    Args:
        image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
            Image to generate crops for.
        target_size (`int`):
            Size of the smallest crop.
        crop_n_layers (`int`, *optional*):
            If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
            to run, where each layer has 2**i_layer number of image crops.
        overlap_ratio (`int`, *optional*):
            Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
        points_per_crop (`int`, *optional*):
            Number of points to sample per crop.
        crop_n_points_downscale_factor (`int`, *optional*):
            The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
        device (`torch.device`, *optional*):
            Device to run the crop generation on. Defaults to CPU.
    """
    if device is None:
        device = torch.device("cpu")

    if isinstance(image, list):
        raise ValueError("Only one image is allowed for crop generation.")
    image = to_numpy_array(image)
    original_size = get_image_size(image)

    points_grid = []
    for i in range(crop_n_layers + 1):
        n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
        points_grid.append(_build_point_grid(n_points))

    crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)

    cropped_images, point_grid_per_crop = _generate_crop_images(
        crop_boxes, image, points_grid, layer_idxs, target_size, original_size
    )

    crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
    point_grid_per_crop = np.array([point_grid_per_crop])
    points_per_crop = torch.tensor(point_grid_per_crop, device=device)
    points_per_crop = points_per_crop.permute(0, 2, 1, 3)

    input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)

    return crop_boxes, points_per_crop, cropped_images, input_labels


def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
    """
    Generates 2 ** (layer idx + 1) crops for each crop_n_layers. Crops are in the XYWH format: the XYWH format
    consists of the following required indices:
        - X: X coordinate of the top left of the bounding box
        - Y: Y coordinate of the top left of the bounding box
        - W: width of the bounding box
        - H: height of the bounding box
    """
    crop_boxes, layer_idxs = [], []
    im_height, im_width = original_size
    short_side = min(im_height, im_width)

    # Original image
    crop_boxes.append([0, 0, im_width, im_height])
    layer_idxs.append(0)
    for i_layer in range(crop_n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
        crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))

        crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]

        for left, top in product(crop_box_x0, crop_box_y0):
            box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs
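# Illustrative example (hypothetical sizes): with crop_n_layers=1, the default overlap_ratio of
# 512/1500 and a 600x800 (height x width) image, layer 0 is the full-image box [0, 0, 800, 600] and
# layer 1 adds 2**1 * 2**1 = 4 overlapping crops of about 502x402 pixels, the first being [0, 0, 502, 402].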
def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_size, original_size):
    """
    Takes as an input bounding boxes that are used to crop the image. Based on the crops, the corresponding points
    are also passed.
    """
    cropped_images = []
    total_points_per_crop = []
    for i, crop_box in enumerate(crop_boxes):
        left, top, right, bottom = crop_box

        channel_dim = infer_channel_dimension_format(image)
        if channel_dim == ChannelDimension.LAST:
            cropped_im = image[top:bottom, left:right, :]
        else:
            cropped_im = image[:, top:bottom, left:right]

        cropped_images.append(cropped_im)

        cropped_im_size = get_image_size(cropped_im)
        points_scale = np.array(cropped_im_size)[None, ::-1]

        points = points_grid[layer_idxs[i]] * points_scale
        normalized_points = _normalize_coordinates(target_size, points, original_size)
        total_points_per_crop.append(normalized_points)

    return cropped_images, total_points_per_crop


def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
    left, top, right, bottom = crop_box
    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
    pad = (left, pad_x - left, top, pad_y - top)
    return torch.nn.functional.pad(masks, pad, value=0)


def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)

    left, top, _, _ = crop_box
    offset = torch.tensor([[left, top, left, top]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    boxes = (boxes + offset).float()

    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    return torch.any(near_crop_edge, dim=1)


def _batched_mask_to_box(masks: "torch.Tensor"):
    """
    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
    corresponds to the following required indices:
        - LEFT: left hand side of the bounding box
        - TOP: top of the bounding box
        - RIGHT: right of the bounding box
        - BOTTOM: bottom of the bounding box

    Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output
    shape is channel_1 x channel_2 x ... x 4.

    Args:
        - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to Cxheightxwidth
    shape = masks.shape
    height, width = shape[-2:]

    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    in_height_coords = in_height_coords + height * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + width * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    out = out.reshape(*shape[:-2], 4)
    return out
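# Illustrative example: a boolean mask that is True exactly on rows 2..4 and columns 5..9 yields the
# box [5, 2, 9, 4], i.e. [left, top, right, bottom] in pixel coordinates.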
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
    """
    Encodes masks to an uncompressed run-length encoding (RLE), in the format expected by pycoco tools.
    """
    # Put in fortran order and flatten height and width
    batch_size, height, width = input_mask.shape
    input_mask = input_mask.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(batch_size):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if input_mask[i, 0] == 0 else [0]
        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
        out.append({"size": [height, width], "counts": counts})
    return out
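# Illustrative example: a single 2x2 mask [[0, 1], [0, 1]] flattens in column-major (Fortran) order
# to [0, 0, 1, 1], which encodes as {"size": [2, 2], "counts": [2, 2]} (a run of 2 zeros followed by
# a run of 2 ones).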
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    height, width = rle["size"]
    mask = np.empty(height * width, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity = not parity
    mask = mask.reshape(width, height)
    return mask.transpose()  # Reshape to original shape
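# Continuing the example above, _rle_to_mask({"size": [2, 2], "counts": [2, 2]}) recovers
# array([[False, True], [False, True]]), so encoding and decoding round-trip exactly.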
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
    """
    Perform NMS (Non Maximum Suppression) on the outputs.

    Args:
        rle_masks (`torch.Tensor`):
            binary masks in the RLE format
        iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
            iou_scores predicted by the model
        mask_boxes (`torch.Tensor`):
            The bounding boxes corresponding to segmentation masks
        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
            NMS threshold.
    """
    keep_by_nms = batched_nms(
        boxes=mask_boxes.float(),
        scores=iou_scores,
        idxs=torch.zeros(mask_boxes.shape[0]),
        iou_threshold=amg_crops_nms_thresh,
    )

    iou_scores = iou_scores[keep_by_nms]
    rle_masks = [rle_masks[i] for i in keep_by_nms]
    mask_boxes = mask_boxes[keep_by_nms]
    masks = [_rle_to_mask(rle) for rle in rle_masks]

    return masks, iou_scores, rle_masks, mask_boxes
@@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline
 from .image_classification import ImageClassificationPipeline
 from .image_segmentation import ImageSegmentationPipeline
 from .image_to_text import ImageToTextPipeline
+from .mask_generation import MaskGenerationPipeline
 from .object_detection import ObjectDetectionPipeline
 from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
 from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
@@ -124,6 +125,7 @@ if is_torch_available():
         AutoModelForImageClassification,
         AutoModelForImageSegmentation,
         AutoModelForMaskedLM,
+        AutoModelForMaskGeneration,
         AutoModelForObjectDetection,
         AutoModelForQuestionAnswering,
         AutoModelForSemanticSegmentation,
@@ -384,6 +386,13 @@ SUPPORTED_TASKS = {
         "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
         "type": "video",
     },
+    "mask-generation": {
+        "impl": MaskGenerationPipeline,
+        "tf": (),
+        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
+        "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
+        "type": "multimodal",
+    },
 }

 NO_FEATURE_EXTRACTOR_TASKS = set()
@@ -536,6 +545,7 @@ def pipeline(
             - `"image-classification"`: will return a [`ImageClassificationPipeline`].
             - `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
             - `"image-to-text"`: will return a [`ImageToTextPipeline`].
+            - `"mask-generation"`: will return a [`MaskGenerationPipeline`].
             - `"object-detection"`: will return a [`ObjectDetectionPipeline`].
             - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
             - `"summarization"`: will return a [`SummarizationPipeline`].
@@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side):
             tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
         elif dim == 3:
             tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
+        elif dim == 4:
+            tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value

         for i, item in enumerate(items):
             if dim == 2:
@@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side):
                     tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
                 else:
                     tensor[i, : len(item[key][0]), :] = item[key][0].clone()
+            elif dim == 4:
+                if padding_side == "left":
+                    tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
+                else:
+                    tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()

         return tensor
     else:
         return [item[key] for item in items]
@@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline):
         )

     def _sanitize_parameters(self, **kwargs):
-        preprocessor_kwargs = {}
+        preprocess_kwargs = {}
         postprocess_kwargs = {}
         if "subtask" in kwargs:
             postprocess_kwargs["subtask"] = kwargs["subtask"]
-            preprocessor_kwargs["subtask"] = kwargs["subtask"]
+            preprocess_kwargs["subtask"] = kwargs["subtask"]
         if "threshold" in kwargs:
             postprocess_kwargs["threshold"] = kwargs["threshold"]
         if "mask_threshold" in kwargs:
@@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline):
         if "overlap_mask_area_threshold" in kwargs:
             postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]

-        return preprocessor_kwargs, {}, postprocess_kwargs
+        return preprocess_kwargs, {}, postprocess_kwargs

     def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
         """
new file (286 lines): src/transformers/pipelines/mask_generation.py
@@ -0,0 +1,286 @@

from collections import defaultdict
from typing import Optional

from ..image_utils import load_image
from ..utils import (
    add_end_docstrings,
    is_torch_available,
    logging,
    requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline


if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING

logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class MaskGenerationPipeline(ChunkPipeline):
    """
    Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
    image, given an image. It is a `ChunkPipeline` because you can separate the points in a mini-batch in order to
    avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at
    the same time. Default is `64`.

    The pipeline works in 3 steps:
        1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
           labels. For more details on how the points and bounding boxes are created, check the
           `_generate_crop_boxes` function. The image is also preprocessed using the `image_processor`. This function
           `yields` a minibatch of `points_per_batch`.

        2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
           Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
           tensors and models are on the same device.

        3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps are
           induced:
            - image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks, resizes
              them according to the image size, and transforms them into binary masks.
            - image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
              `stability_scores`. Also applies a variety of filters based on non maximum suppression to remove bad
              masks.
            - image_processor.postprocess_masks_for_amg applies the NMS on the masks to only keep relevant ones.

    Arguments:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        feature_extractor ([`SequenceFeatureExtractor`]):
            The feature extractor that will be used by the pipeline to encode the input.
        points_per_batch (`int`, *optional*, defaults to 64):
            Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
            memory.
        output_bboxes_mask (`bool`, *optional*, defaults to `False`):
            Whether or not to output the bounding box predictions.
        output_rle_masks (`bool`, *optional*, defaults to `False`):
            Whether or not to output the masks in `RLE` format.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="facebook/sam-vit-huge", task="mask-generation")
    >>> outputs = generator(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ... )

    >>> outputs = generator(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
    ... )
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"mask-generation"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        requires_backends(self, "vision")
        requires_backends(self, "torch")

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        forward_params = {}
        # preprocess args
        if "points_per_batch" in kwargs:
            preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
        if "points_per_crop" in kwargs:
            preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
        if "crops_n_layers" in kwargs:
            preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
        if "crop_overlap_ratio" in kwargs:
            preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
        if "crop_n_points_downscale_factor" in kwargs:
            preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
        # postprocess args
        if "pred_iou_thresh" in kwargs:
            forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
        if "stability_score_offset" in kwargs:
            forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
        if "mask_threshold" in kwargs:
            forward_params["mask_threshold"] = kwargs["mask_threshold"]
        if "stability_score_thresh" in kwargs:
            forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
        if "crops_nms_thresh" in kwargs:
            postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
        if "output_rle_mask" in kwargs:
            postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
        if "output_bboxes_mask" in kwargs:
            postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
        return preprocess_kwargs, forward_params, postprocess_kwargs

    def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
        """
        Generates binary segmentation masks

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                Image or list of images.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold to use when turning the predicted masks into binary values.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
                binarize the model's mask predictions.
            stability_score_offset (`int`, *optional*, defaults to 1):
                The amount to shift the cutoff when calculating the stability score.
            crops_nms_thresh (`float`, *optional*, defaults to 0.7):
                The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
            crops_n_layers (`int`, *optional*, defaults to 0):
                If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
                layers to run, where each layer has 2**i_layer number of image crops.
            crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction
                of the image length. Later layers with more crops scale down this overlap.
            crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.

        Return:
            `Dict`: A dictionary with the following keys:
                - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
                  height)` of the original image. Returns a mask filled with zeros if no object is found.
                - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence
                  of the "object" described by the label and the mask.
        """
        return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
    def preprocess(
        self,
        image,
        points_per_batch=64,
        crops_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        points_per_crop: Optional[int] = 32,
        crop_n_points_downscale_factor: Optional[int] = 1,
    ):
        image = load_image(image)
        target_size = self.image_processor.size["longest_edge"]
        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")

        with self.device_placement():
            if self.framework == "pt":
                inference_context = self.get_inference_context()
                with inference_context():
                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                    image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
                    model_inputs["image_embeddings"] = image_embeddings

        n_points = grid_points.shape[1]
        points_per_batch = points_per_batch if points_per_batch is not None else n_points

        if points_per_batch <= 0:
            raise ValueError(
                "Cannot have points_per_batch<=0. Must be >=1 to return batched outputs. "
                "To return all points at once, set points_per_batch to None"
            )

        for i in range(0, n_points, points_per_batch):
            batched_points = grid_points[:, i : i + points_per_batch, :, :]
            labels = input_labels[:, i : i + points_per_batch]
            is_last = i == n_points - points_per_batch
            yield {
                "input_points": batched_points,
                "input_labels": labels,
                "input_boxes": crop_boxes,
                "is_last": is_last,
                **model_inputs,
            }

    def _forward(
        self,
        model_inputs,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        input_boxes = model_inputs.pop("input_boxes")
        is_last = model_inputs.pop("is_last")
        original_sizes = model_inputs.pop("original_sizes").tolist()
        reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()

        model_outputs = self.model(**model_inputs)

        # post processing happens here in order to avoid CPU GPU copies of ALL the masks
        low_resolution_masks = model_outputs["pred_masks"]
        masks = self.image_processor.post_process_masks(
            low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
        )
        iou_scores = model_outputs["iou_scores"]
        masks, iou_scores, boxes = self.image_processor.filter_masks(
            masks[0],
            iou_scores[0],
            original_sizes[0],
            input_boxes[0],
            pred_iou_thresh,
            stability_score_thresh,
            mask_threshold,
            stability_score_offset,
        )
        return {
            "masks": masks,
            "is_last": is_last,
            "boxes": boxes,
            "iou_scores": iou_scores,
        }

    def postprocess(
        self,
        model_outputs,
        output_rle_mask=False,
        output_bboxes_mask=False,
        crops_nms_thresh=0.7,
    ):
        all_scores = []
        all_masks = []
        all_boxes = []
        for model_output in model_outputs:
            all_scores.append(model_output.pop("iou_scores"))
            all_masks.extend(model_output.pop("masks"))
            all_boxes.append(model_output.pop("boxes"))

        all_scores = torch.cat(all_scores)
        all_boxes = torch.cat(all_boxes)
        output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
            all_masks, all_scores, all_boxes, crops_nms_thresh
        )

        extra = defaultdict(list)
        for output in model_outputs:
            for k, v in output.items():
                extra[k].append(v)

        optional = {}
        if output_rle_mask:
            optional["rle_mask"] = rle_mask

        if output_bboxes_mask:
            optional["bounding_boxes"] = bounding_boxes

        return {"masks": output_masks, "scores": iou_scores, **optional, **extra}
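A short usage sketch of the optional outputs wired through `_sanitize_parameters` and `postprocess` above (checkpoint and image URL are the ones used in the tests added by this commit):

```python
from transformers import pipeline

generator = pipeline("mask-generation", model="facebook/sam-vit-huge", points_per_batch=256)
outputs = generator(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    output_rle_mask=True,     # adds an "rle_mask" entry to the returned dict
    output_bboxes_mask=True,  # adds a "bounding_boxes" entry to the returned dict
)
print(len(outputs["masks"]), outputs["scores"][:3])
```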
@@ -475,6 +475,9 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
 MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None


+MODEL_FOR_MASK_GENERATION_MAPPING = None
+
+
 MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
@@ -464,8 +464,8 @@ class SamModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))

     def test_inference_mask_generation_one_point_one_bb(self):
-        model = SamModel.from_pretrained("facebook/sam-vit-h")
-        processor = SamProcessor.from_pretrained("facebook/sam-vit-h")
+        model = SamModel.from_pretrained("facebook/sam-vit-huge")
+        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

         model.to(torch_device)
         model.eval()
new file (142 lines): tests/pipelines/test_pipelines_mask_generation.py
@@ -0,0 +1,142 @@

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import unittest
from typing import Dict

import numpy as np

from transformers import MODEL_FOR_MASK_GENERATION_MAPPING, is_vision_available, pipeline
from transformers.pipelines import MaskGenerationPipeline
from transformers.testing_utils import (
    is_pipeline_test,
    nested_simplify,
    require_tf,
    require_torch,
    require_vision,
    slow,
)


if is_vision_available():
    from PIL import Image


def hashimage(image: Image) -> str:
    m = hashlib.md5(image.tobytes())
    return m.hexdigest()[:10]


def mask_to_test_readable(mask: Image) -> Dict:
    npimg = np.array(mask)
    shape = npimg.shape
    return {"hash": hashimage(mask), "shape": shape}


@is_pipeline_test
@require_vision
@require_torch
class MaskGenerationPipelineTests(unittest.TestCase):
    model_mapping = dict(
        (list(MODEL_FOR_MASK_GENERATION_MAPPING.items()) if MODEL_FOR_MASK_GENERATION_MAPPING else [])
    )

    def get_test_pipeline(self, model, tokenizer, processor):
        image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor)
        return image_segmenter, [
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
        ]

    @require_tf
    @unittest.skip("Image segmentation not implemented in TF")
    def test_small_model_tf(self):
        pass

    @slow
    @require_torch
    def test_small_model_pt(self):
        image_segmenter = pipeline("mask-generation", model="facebook/sam-vit-huge")

        outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", points_per_batch=256)

        # Shortening by hashing
        new_output = []
        for i, o in enumerate(outputs["masks"]):
            new_output += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]

        # fmt: off
        self.assertEqual(
            nested_simplify(new_output, decimals=4),
            [
                {'mask': {'hash': '115ad19f5f', 'shape': (480, 640)}, 'scores': 1.0444},
                {'mask': {'hash': '6affa964c6', 'shape': (480, 640)}, 'scores': 1.021},
                {'mask': {'hash': 'dfe28a0388', 'shape': (480, 640)}, 'scores': 1.0167},
                {'mask': {'hash': 'c0a5f4a318', 'shape': (480, 640)}, 'scores': 1.0132},
                {'mask': {'hash': 'fe8065c197', 'shape': (480, 640)}, 'scores': 1.0053},
                {'mask': {'hash': 'e2d0b7a0b7', 'shape': (480, 640)}, 'scores': 0.9967},
                {'mask': {'hash': '453c7844bd', 'shape': (480, 640)}, 'scores': 0.993},
                {'mask': {'hash': '3d44f2926d', 'shape': (480, 640)}, 'scores': 0.9909},
                {'mask': {'hash': '64033ddc3f', 'shape': (480, 640)}, 'scores': 0.9879},
                {'mask': {'hash': '801064ff79', 'shape': (480, 640)}, 'scores': 0.9834},
                {'mask': {'hash': '6172f276ef', 'shape': (480, 640)}, 'scores': 0.9716},
                {'mask': {'hash': 'b49e60e084', 'shape': (480, 640)}, 'scores': 0.9612},
                {'mask': {'hash': 'a811e775fd', 'shape': (480, 640)}, 'scores': 0.9599},
                {'mask': {'hash': 'a6a8ebcf4b', 'shape': (480, 640)}, 'scores': 0.9552},
                {'mask': {'hash': '9d8257e080', 'shape': (480, 640)}, 'scores': 0.9532},
                {'mask': {'hash': '32de6454a8', 'shape': (480, 640)}, 'scores': 0.9516},
                {'mask': {'hash': 'af3d4af2c8', 'shape': (480, 640)}, 'scores': 0.9499},
                {'mask': {'hash': '3c6db475fb', 'shape': (480, 640)}, 'scores': 0.9483},
                {'mask': {'hash': 'c290813fb9', 'shape': (480, 640)}, 'scores': 0.9464},
                {'mask': {'hash': 'b6f0b8f606', 'shape': (480, 640)}, 'scores': 0.943},
                {'mask': {'hash': '92ce16bfdf', 'shape': (480, 640)}, 'scores': 0.943},
                {'mask': {'hash': 'c749b25868', 'shape': (480, 640)}, 'scores': 0.9408},
                {'mask': {'hash': 'efb6cab859', 'shape': (480, 640)}, 'scores': 0.9335},
                {'mask': {'hash': '1ff2eafb30', 'shape': (480, 640)}, 'scores': 0.9326},
                {'mask': {'hash': '788b798e24', 'shape': (480, 640)}, 'scores': 0.9262},
                {'mask': {'hash': 'abea804f0e', 'shape': (480, 640)}, 'scores': 0.8999},
                {'mask': {'hash': '7b9e8ddb73', 'shape': (480, 640)}, 'scores': 0.8986},
                {'mask': {'hash': 'cd24047c8a', 'shape': (480, 640)}, 'scores': 0.8984},
                {'mask': {'hash': '6943e6bcbd', 'shape': (480, 640)}, 'scores': 0.8873},
                {'mask': {'hash': 'b5f47c9191', 'shape': (480, 640)}, 'scores': 0.8871}
            ],
        )
        # fmt: on

    @require_torch
    @slow
    def test_threshold(self):
        model_id = "facebook/sam-vit-huge"
        image_segmenter = pipeline("mask-generation", model=model_id)

        outputs = image_segmenter(
            "http://images.cocodataset.org/val2017/000000039769.jpg", pred_iou_thresh=1, points_per_batch=256
        )

        # Shortening by hashing
        new_output = []
        for i, o in enumerate(outputs["masks"]):
            new_output += [{"mask": mask_to_test_readable(o), "scores": outputs["scores"][i]}]

        self.assertEqual(
            nested_simplify(new_output, decimals=4),
            [
                {"mask": {"hash": "115ad19f5f", "shape": (480, 640)}, "scores": 1.0444},
                {"mask": {"hash": "6affa964c6", "shape": (480, 640)}, "scores": 1.0210},
                {"mask": {"hash": "dfe28a0388", "shape": (480, 640)}, "scores": 1.0167},
                {"mask": {"hash": "c0a5f4a318", "shape": (480, 640)}, "scores": 1.0132},
                {"mask": {"hash": "fe8065c197", "shape": (480, 640)}, "scores": 1.0053},
            ],
        )
@@ -98,6 +98,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
     ),
     ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
     ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
+    ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
 ]