mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix image segmentation pipeline errors, resolve backward compatibility issues (#19768)
* Fix panoptic segmentation and pipeline * Update ImageSegmentationPipeline tests and reenable test_small_model_pt * Resolve backward compatibility issues
This commit is contained in:
parent
b58d4f70f6
commit
cca51aa151
@ -190,13 +190,13 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
|
||||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||
|
||||
|
||||
def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8):
|
||||
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
|
||||
# Get the mask associated with the k class
|
||||
mask_k = mask_labels == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_area = (mask_probs[k] >= 0.5).sum()
|
||||
original_area = (mask_probs[k] >= mask_threshold).sum()
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
# Eliminate disconnected tiny segments
|
||||
@ -212,6 +212,7 @@ def compute_segments(
|
||||
mask_probs,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_size: Tuple[int, int] = None,
|
||||
@ -240,7 +241,9 @@ def compute_segments(
|
||||
should_fuse = pred_class in label_ids_to_fuse
|
||||
|
||||
# Check if mask exists and large enough to be a segment
|
||||
mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
|
||||
mask_exists, mask_k = check_segment_validity(
|
||||
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
|
||||
)
|
||||
|
||||
if mask_exists:
|
||||
if pred_class in stuff_memory_list:
|
||||
@ -1210,6 +1213,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
self,
|
||||
outputs,
|
||||
threshold: float = 0.5,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
return_coco_annotation: Optional[bool] = False,
|
||||
@ -1221,6 +1225,8 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
@ -1272,6 +1278,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
mask_threshold,
|
||||
overlap_mask_area_threshold,
|
||||
target_size,
|
||||
)
|
||||
@ -1287,6 +1294,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
self,
|
||||
outputs,
|
||||
threshold: float = 0.5,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
@ -1299,6 +1307,8 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
The outputs from [`DetrForSegmentation`].
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
@ -1359,6 +1369,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
mask_threshold,
|
||||
overlap_mask_area_threshold,
|
||||
label_ids_to_fuse,
|
||||
target_size,
|
||||
|
@ -37,15 +37,14 @@ if is_torch_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.binary_mask_to_rle
|
||||
def binary_mask_to_rle(mask):
|
||||
"""
|
||||
Converts given binary mask of shape (height, width) to the run-length encoding (RLE) format.
|
||||
|
||||
Args:
|
||||
Converts given binary mask of shape (height, width) to the run-length encoding (RLE) format.
|
||||
mask (`torch.Tensor` or `numpy.array`):
|
||||
A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
|
||||
segment_id or class_id.
|
||||
|
||||
Returns:
|
||||
`List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
|
||||
format.
|
||||
@ -60,6 +59,7 @@ def binary_mask_to_rle(mask):
|
||||
return [x for x in runs]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.convert_segmentation_to_rle
|
||||
def convert_segmentation_to_rle(segmentation):
|
||||
"""
|
||||
Converts given segmentation map of shape (height, width) to the run-length encoding (RLE) format.
|
||||
@ -67,7 +67,6 @@ def convert_segmentation_to_rle(segmentation):
|
||||
Args:
|
||||
segmentation (`torch.Tensor` or `numpy.array`):
|
||||
A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
|
||||
|
||||
Returns:
|
||||
`List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
|
||||
"""
|
||||
@ -82,6 +81,7 @@ def convert_segmentation_to_rle(segmentation):
|
||||
return run_length_encodings
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.remove_low_and_no_objects
|
||||
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
||||
"""
|
||||
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
|
||||
@ -96,10 +96,8 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
|
||||
A tensor of shape `(num_queries)`.
|
||||
object_mask_threshold (`float`):
|
||||
A number between 0 and 1 used to binarize the masks.
|
||||
|
||||
Raises:
|
||||
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
|
||||
|
||||
Returns:
|
||||
`Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
|
||||
< `object_mask_threshold`.
|
||||
@ -108,16 +106,18 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
|
||||
raise ValueError("mask, scores and labels must have the same shape!")
|
||||
|
||||
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
|
||||
|
||||
return masks[to_keep], scores[to_keep], labels[to_keep]
|
||||
|
||||
|
||||
def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8):
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.check_segment_validity
|
||||
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
|
||||
# Get the mask associated with the k class
|
||||
mask_k = mask_labels == k
|
||||
mask_k_area = mask_k.sum()
|
||||
|
||||
# Compute the area of all the stuff in query k
|
||||
original_area = (mask_probs[k] >= 0.5).sum()
|
||||
original_area = (mask_probs[k] >= mask_threshold).sum()
|
||||
mask_exists = mask_k_area > 0 and original_area > 0
|
||||
|
||||
# Eliminate disconnected tiny segments
|
||||
@ -129,10 +129,12 @@ def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_thresho
|
||||
return mask_exists, mask_k
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.compute_segments
|
||||
def compute_segments(
|
||||
mask_probs,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_size: Tuple[int, int] = None,
|
||||
@ -144,7 +146,9 @@ def compute_segments(
|
||||
segments: List[Dict] = []
|
||||
|
||||
if target_size is not None:
|
||||
mask_probs = interpolate(mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False)[0]
|
||||
mask_probs = nn.functional.interpolate(
|
||||
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
|
||||
)[0]
|
||||
|
||||
current_segment_id = 0
|
||||
|
||||
@ -159,7 +163,9 @@ def compute_segments(
|
||||
should_fuse = pred_class in label_ids_to_fuse
|
||||
|
||||
# Check if mask exists and large enough to be a segment
|
||||
mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
|
||||
mask_exists, mask_k = check_segment_validity(
|
||||
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
|
||||
)
|
||||
|
||||
if mask_exists:
|
||||
if pred_class in stuff_memory_list:
|
||||
@ -722,6 +728,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
self,
|
||||
outputs,
|
||||
threshold: float = 0.5,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
return_coco_annotation: Optional[bool] = False,
|
||||
@ -735,6 +742,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
Raw outputs of the model.
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
@ -786,6 +795,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
mask_threshold,
|
||||
overlap_mask_area_threshold,
|
||||
target_size,
|
||||
)
|
||||
@ -801,6 +811,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
self,
|
||||
outputs,
|
||||
threshold: float = 0.5,
|
||||
mask_threshold: float = 0.5,
|
||||
overlap_mask_area_threshold: float = 0.8,
|
||||
label_ids_to_fuse: Optional[Set[int]] = None,
|
||||
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
||||
@ -814,6 +825,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
The outputs from [`MaskFormerForInstanceSegmentation`].
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
The probability score threshold to keep predicted instance masks.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
||||
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
||||
instance mask.
|
||||
@ -875,6 +888,7 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
|
||||
mask_probs_item,
|
||||
pred_scores_item,
|
||||
pred_labels_item,
|
||||
mask_threshold,
|
||||
overlap_mask_area_threshold,
|
||||
label_ids_to_fuse,
|
||||
target_size,
|
||||
|
@ -60,6 +60,8 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
postprocess_kwargs["task"] = kwargs["task"]
|
||||
if "threshold" in kwargs:
|
||||
postprocess_kwargs["threshold"] = kwargs["threshold"]
|
||||
if "mask_threshold" in kwargs:
|
||||
postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
|
||||
if "overlap_mask_area_threshold" in kwargs:
|
||||
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
|
||||
return {}, {}, postprocess_kwargs
|
||||
@ -78,11 +80,13 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
|
||||
The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
|
||||
same format: all as HTTP(S) links, all as local paths, or all as PIL images.
|
||||
task (`str`, defaults to `semantic`):
|
||||
subtask (`str`, defaults to `panoptic`):
|
||||
Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
|
||||
capabilities.
|
||||
threshold (`float`, *optional*, defaults to 0.9):
|
||||
Probability threshold to filter out predicted masks.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Mask overlap threshold to eliminate small, disconnected segments.
|
||||
|
||||
@ -116,11 +120,16 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
model_outputs["target_size"] = target_size
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, task="semantic", threshold=0.9, overlap_mask_area_threshold=0.5):
|
||||
if task == "instance" and hasattr(self.feature_extractor, "post_process_instance_segmentation"):
|
||||
def postprocess(
|
||||
self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
|
||||
):
|
||||
if (subtask == "panoptic" or subtask is None) and hasattr(
|
||||
self.feature_extractor, "post_process_panoptic_segmentation"
|
||||
):
|
||||
outputs = self.feature_extractor.post_process_panoptic_segmentation(
|
||||
model_outputs,
|
||||
threshold=threshold,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_sizes=model_outputs["target_size"],
|
||||
)[0]
|
||||
@ -130,29 +139,7 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
|
||||
if len(outputs["segments_info"]) == 0:
|
||||
mask = Image.fromarray(np.zeros(segmentation.shape).astype(np.uint8), mode="L")
|
||||
annotation.append({"mask": mask, "label": None, "score": 0.0})
|
||||
else:
|
||||
for segment in outputs["segments_info"]:
|
||||
mask = (segmentation == segment["id"]) * 255
|
||||
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
|
||||
label = self.model.config.id2label[segment["label_id"]]
|
||||
score = segment["score"]
|
||||
annotation.append({"mask": mask, "label": label, "score": score})
|
||||
|
||||
elif task == "panoptic" and hasattr(self.feature_extractor, "post_process_panoptic_segmentation"):
|
||||
outputs = self.feature_extractor.post_process_panoptic_segmentation(
|
||||
model_outputs,
|
||||
threshold=threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_sizes=model_outputs["target_size"],
|
||||
)[0]
|
||||
|
||||
annotation = []
|
||||
segmentation = outputs["segmentation"]
|
||||
|
||||
if len(outputs["segments_info"]) == 0:
|
||||
mask = Image.fromarray(np.zeros(segmentation.shape).astype(np.uint8), mode="L")
|
||||
annotation.append({"mask": mask, "label": None, "score": 0.0})
|
||||
annotation.append({"mask": mask, "label": "NULL", "score": 0.0})
|
||||
else:
|
||||
for segment in outputs["segments_info"]:
|
||||
mask = (segmentation == segment["id"]) * 255
|
||||
@ -161,7 +148,34 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
score = segment["score"]
|
||||
annotation.append({"score": score, "label": label, "mask": mask})
|
||||
|
||||
elif task == "semantic" and hasattr(self.feature_extractor, "post_process_semantic_segmentation"):
|
||||
elif (subtask == "instance" or subtask is None) and hasattr(
|
||||
self.feature_extractor, "post_process_instance_segmentation"
|
||||
):
|
||||
outputs = self.feature_extractor.post_process_instance_segmentation(
|
||||
model_outputs,
|
||||
threshold=threshold,
|
||||
mask_threshold=mask_threshold,
|
||||
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
||||
target_sizes=model_outputs["target_size"],
|
||||
)[0]
|
||||
|
||||
annotation = []
|
||||
segmentation = outputs["segmentation"]
|
||||
|
||||
if len(outputs["segments_info"]) == 0:
|
||||
mask = Image.fromarray(np.zeros(segmentation.shape).astype(np.uint8), mode="L")
|
||||
annotation.append({"mask": mask, "label": "NULL", "score": 0.0})
|
||||
else:
|
||||
for segment in outputs["segments_info"]:
|
||||
mask = (segmentation == segment["id"]) * 255
|
||||
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
|
||||
label = self.model.config.id2label[segment["label_id"]]
|
||||
score = segment["score"]
|
||||
annotation.append({"mask": mask, "label": label, "score": score})
|
||||
|
||||
elif (subtask == "semantic" or subtask is None) and hasattr(
|
||||
self.feature_extractor, "post_process_semantic_segmentation"
|
||||
):
|
||||
outputs = self.feature_extractor.post_process_semantic_segmentation(
|
||||
model_outputs, target_sizes=model_outputs["target_size"]
|
||||
)[0]
|
||||
@ -176,5 +190,5 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
label = self.model.config.id2label[label]
|
||||
annotation.append({"score": None, "label": label, "mask": mask})
|
||||
else:
|
||||
raise ValueError(f"task {task} is not supported for model {self.model}")
|
||||
raise ValueError(f"Task {subtask} is not supported for model {self.model}.s")
|
||||
return annotation
|
||||
|
@ -399,13 +399,11 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
|
||||
|
||||
self.assertEqual(segmentation[0].shape, target_sizes[0])
|
||||
|
||||
@unittest.skip("Fix me Alara!")
|
||||
def test_post_process_panoptic_segmentation(self):
|
||||
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
|
||||
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
|
||||
segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, threshold=0)
|
||||
print(len(segmentation))
|
||||
print(self.feature_extract_tester.batch_size)
|
||||
|
||||
self.assertTrue(len(segmentation) == self.feature_extract_tester.batch_size)
|
||||
for el in segmentation:
|
||||
self.assertTrue("segmentation" in el)
|
||||
|
@ -81,7 +81,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
]
|
||||
|
||||
def run_pipeline_test(self, image_segmenter, examples):
|
||||
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
||||
outputs = image_segmenter(
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
threshold=0.0,
|
||||
mask_threshold=0,
|
||||
overlap_mask_area_threshold=0,
|
||||
)
|
||||
self.assertIsInstance(outputs, list)
|
||||
n = len(outputs)
|
||||
if isinstance(image_segmenter.model, (MaskFormerForInstanceSegmentation)):
|
||||
@ -97,15 +102,15 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
|
||||
# RGBA
|
||||
outputs = image_segmenter(dataset[0]["file"])
|
||||
outputs = image_segmenter(dataset[0]["file"], threshold=0.0, mask_threshold=0, overlap_mask_area_threshold=0)
|
||||
m = len(outputs)
|
||||
self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * m, outputs)
|
||||
# LA
|
||||
outputs = image_segmenter(dataset[1]["file"])
|
||||
outputs = image_segmenter(dataset[1]["file"], threshold=0.0, mask_threshold=0, overlap_mask_area_threshold=0)
|
||||
m = len(outputs)
|
||||
self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * m, outputs)
|
||||
# L
|
||||
outputs = image_segmenter(dataset[2]["file"])
|
||||
outputs = image_segmenter(dataset[2]["file"], threshold=0.0, mask_threshold=0, overlap_mask_area_threshold=0)
|
||||
m = len(outputs)
|
||||
self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * m, outputs)
|
||||
|
||||
@ -126,7 +131,9 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
]
|
||||
outputs = image_segmenter(batch, threshold=0.0, batch_size=batch_size)
|
||||
outputs = image_segmenter(
|
||||
batch, threshold=0.0, mask_threshold=0, overlap_mask_area_threshold=0, batch_size=batch_size
|
||||
)
|
||||
self.assertEqual(len(batch), len(outputs))
|
||||
self.assertEqual(len(outputs[0]), n)
|
||||
self.assertEqual(
|
||||
@ -152,55 +159,29 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
model = AutoModelForImageSegmentation.from_pretrained(model_id)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
image_segmenter = ImageSegmentationPipeline(
|
||||
model=model,
|
||||
feature_extractor=feature_extractor,
|
||||
task="semantic",
|
||||
threshold=0.0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
|
||||
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
subtask="panoptic",
|
||||
threshold=0.0,
|
||||
mask_threshold=0.0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
# This is extremely brittle, and those values are made specific for the CI.
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{
|
||||
"label": "LABEL_88",
|
||||
"mask": {"hash": "7f0bf661a4", "shape": (480, 640), "white_pixels": 3},
|
||||
"score": None,
|
||||
},
|
||||
{
|
||||
"label": "LABEL_101",
|
||||
"mask": {"hash": "10ab738dc9", "shape": (480, 640), "white_pixels": 8948},
|
||||
"score": None,
|
||||
},
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": {"hash": "b431e0946c", "shape": (480, 640), "white_pixels": 298249},
|
||||
"score": None,
|
||||
"mask": {"hash": "a01498ca7c", "shape": (480, 640), "white_pixels": 307200},
|
||||
},
|
||||
]
|
||||
# Temporary: Keeping around the old values as they might provide useful later
|
||||
# [
|
||||
# {
|
||||
# "score": 0.004,
|
||||
# "label": "LABEL_215",
|
||||
# "mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
# },
|
||||
# {
|
||||
# "score": 0.004,
|
||||
# "label": "LABEL_215",
|
||||
# "mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
# },
|
||||
# ],
|
||||
],
|
||||
)
|
||||
|
||||
outputs = image_segmenter(
|
||||
@ -209,6 +190,8 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
],
|
||||
threshold=0.0,
|
||||
mask_threshold=0.0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
for output in outputs:
|
||||
for o in output:
|
||||
@ -219,62 +202,18 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
[
|
||||
[
|
||||
{
|
||||
"label": "LABEL_88",
|
||||
"mask": {"hash": "7f0bf661a4", "shape": (480, 640), "white_pixels": 3},
|
||||
"score": None,
|
||||
},
|
||||
{
|
||||
"label": "LABEL_101",
|
||||
"mask": {"hash": "10ab738dc9", "shape": (480, 640), "white_pixels": 8948},
|
||||
"score": None,
|
||||
},
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": {"hash": "b431e0946c", "shape": (480, 640), "white_pixels": 298249},
|
||||
"score": None,
|
||||
"mask": {"hash": "a01498ca7c", "shape": (480, 640), "white_pixels": 307200},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"label": "LABEL_88",
|
||||
"mask": {"hash": "7f0bf661a4", "shape": (480, 640), "white_pixels": 3},
|
||||
"score": None,
|
||||
},
|
||||
{
|
||||
"label": "LABEL_101",
|
||||
"mask": {"hash": "10ab738dc9", "shape": (480, 640), "white_pixels": 8948},
|
||||
"score": None,
|
||||
},
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": {"hash": "b431e0946c", "shape": (480, 640), "white_pixels": 298249},
|
||||
"score": None,
|
||||
"mask": {"hash": "a01498ca7c", "shape": (480, 640), "white_pixels": 307200},
|
||||
},
|
||||
]
|
||||
# [
|
||||
# {
|
||||
# "score": 0.004,
|
||||
# "label": "LABEL_215",
|
||||
# "mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
# },
|
||||
# {
|
||||
# "score": 0.004,
|
||||
# "label": "LABEL_215",
|
||||
# "mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
# },
|
||||
# ],
|
||||
# [
|
||||
# {
|
||||
# "score": 0.004,
|
||||
# "label": "LABEL_215",
|
||||
# "mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
# },
|
||||
# {
|
||||
# "score": 0.004,
|
||||
# "label": "LABEL_215",
|
||||
# "mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
# },
|
||||
# ],
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
@ -311,7 +250,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
task="panoptic",
|
||||
subtask="panoptic",
|
||||
threshold=0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
@ -361,7 +300,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
],
|
||||
task="panoptic",
|
||||
subtask="panoptic",
|
||||
threshold=0.0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
@ -448,7 +387,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
image_segmenter = pipeline("image-segmentation", model=model_id)
|
||||
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", task="panoptic", threshold=0.999
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", subtask="panoptic", threshold=0.999
|
||||
)
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
@ -471,7 +410,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
)
|
||||
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", task="panoptic", threshold=0.5
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", subtask="panoptic", threshold=0.5
|
||||
)
|
||||
|
||||
for o in outputs:
|
||||
@ -521,7 +460,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
file = image[0]["file"]
|
||||
outputs = image_segmenter(file, task="panoptic", threshold=threshold)
|
||||
outputs = image_segmenter(file, subtask="panoptic", threshold=threshold)
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
|
Loading…
Reference in New Issue
Block a user