Add ForInstanceSegmentation models to image-segmentation pipelines (#15937)

* Adding ForInstanceSegmentation to pipelines.

* Last fix `category_id` renamed to `label_id`.

* Can't be none no more.

* No `is_thing_map` anymore.
This commit is contained in:
Nicolas Patry 2022-03-09 10:19:05 +01:00 committed by GitHub
parent 5b7dcc7342
commit f4e4ad34cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 13 deletions

View File

@ -18,6 +18,7 @@ if is_torch_available():
from ..models.auto.modeling_auto import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
)
@ -32,10 +33,10 @@ Predictions = List[Prediction]
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ImageSegmentationPipeline(Pipeline):
"""
Image segmentation pipeline using any `AutoModelForImageSegmentation`. This pipeline predicts masks of objects and
Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
their classes.
This image segmntation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"image-segmentation"`.
See the list of available models on
@ -50,7 +51,11 @@ class ImageSegmentationPipeline(Pipeline):
requires_backends(self, "vision")
self.check_model_type(
dict(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items() + MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items())
dict(
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items()
)
)
def _sanitize_parameters(self, **kwargs):
@ -112,14 +117,14 @@ class ImageSegmentationPipeline(Pipeline):
def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5):
if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"):
outputs = self.feature_extractor.post_process_panoptic_segmentation(
model_outputs, is_thing_map=self.model.config.id2label
model_outputs, object_mask_threshold=threshold
)[0]
annotation = []
segmentation = outputs["segmentation"]
for segment in outputs["segments"]:
mask = (segmentation == segment["id"]) * 255
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
label = self.model.config.id2label[segment["category_id"]]
label = self.model.config.id2label[segment["label_id"]]
annotation.append({"mask": mask, "label": label, "score": None})
elif hasattr(self.feature_extractor, "post_process_segmentation"):
# Panoptic

View File

@ -20,11 +20,14 @@ from datasets import load_dataset
from transformers import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
AutoFeatureExtractor,
AutoModelForImageSegmentation,
AutoModelForInstanceSegmentation,
DetrForSegmentation,
ImageSegmentationPipeline,
MaskFormerForInstanceSegmentation,
is_vision_available,
pipeline,
)
@ -67,6 +70,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
list(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()) if MODEL_FOR_IMAGE_SEGMENTATION_MAPPING else []
)
+ (MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items() if MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING else [])
+ (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
}
def get_test_pipeline(self, model, tokenizer, feature_extractor):
@ -80,7 +84,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
self.assertIsInstance(outputs, list)
n = len(outputs)
self.assertGreater(n, 1)
if isinstance(image_segmenter.model, (MaskFormerForInstanceSegmentation)):
# Instance segmentation (maskformer) have a slot for null class
# and can output nothing even with a low threshold
self.assertGreaterEqual(n, 0)
else:
self.assertGreaterEqual(n, 1)
# XXX: PIL.Image implements __eq__ which bypasses ANY, so we inverse the comparison
# to make it work
self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * n, outputs)
@ -119,7 +128,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
]
outputs = image_segmenter(batch, threshold=0.0, batch_size=batch_size)
self.assertEqual(len(batch), len(outputs))
self.assertEqual({"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}, outputs[0][0])
self.assertEqual(len(outputs[0]), n)
self.assertEqual(
[
@ -313,18 +321,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
@require_torch
@slow
def test_maskformer(self):
threshold = 0.999
threshold = 0.8
model_id = "facebook/maskformer-swin-base-ade"
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_id)
model = AutoModelForInstanceSegmentation.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor)
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
outputs = image_segmenter(image[0]["file"], threshold=threshold)
file = image[0]["file"]
outputs = image_segmenter(file, threshold=threshold)
for o in outputs:
o["mask"] = hashimage(o["mask"])