mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add ForInstanceSegmentation
models to image-segmentation
pipelines (#15937)
* Adding ForInstanceSegmentation to pipelines. * Last fix `category_id` renamed to `label_id`. * Can't be none no more. * No `is_thing_map` anymore.
This commit is contained in:
parent
5b7dcc7342
commit
f4e4ad34cc
@ -18,6 +18,7 @@ if is_torch_available():
|
||||
|
||||
from ..models.auto.modeling_auto import (
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
|
||||
)
|
||||
|
||||
@ -32,10 +33,10 @@ Predictions = List[Prediction]
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class ImageSegmentationPipeline(Pipeline):
|
||||
"""
|
||||
Image segmentation pipeline using any `AutoModelForImageSegmentation`. This pipeline predicts masks of objects and
|
||||
Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
|
||||
their classes.
|
||||
|
||||
This image segmntation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"image-segmentation"`.
|
||||
|
||||
See the list of available models on
|
||||
@ -50,7 +51,11 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(
|
||||
dict(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items() + MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items())
|
||||
dict(
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()
|
||||
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()
|
||||
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items()
|
||||
)
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
@ -112,14 +117,14 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
def postprocess(self, model_outputs, raw_image=False, threshold=0.9, mask_threshold=0.5):
|
||||
if hasattr(self.feature_extractor, "post_process_panoptic_segmentation"):
|
||||
outputs = self.feature_extractor.post_process_panoptic_segmentation(
|
||||
model_outputs, is_thing_map=self.model.config.id2label
|
||||
model_outputs, object_mask_threshold=threshold
|
||||
)[0]
|
||||
annotation = []
|
||||
segmentation = outputs["segmentation"]
|
||||
for segment in outputs["segments"]:
|
||||
mask = (segmentation == segment["id"]) * 255
|
||||
mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
|
||||
label = self.model.config.id2label[segment["category_id"]]
|
||||
label = self.model.config.id2label[segment["label_id"]]
|
||||
annotation.append({"mask": mask, "label": label, "score": None})
|
||||
elif hasattr(self.feature_extractor, "post_process_segmentation"):
|
||||
# Panoptic
|
||||
|
@ -20,11 +20,14 @@ from datasets import load_dataset
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForInstanceSegmentation,
|
||||
DetrForSegmentation,
|
||||
ImageSegmentationPipeline,
|
||||
MaskFormerForInstanceSegmentation,
|
||||
is_vision_available,
|
||||
pipeline,
|
||||
)
|
||||
@ -67,6 +70,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
list(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()) if MODEL_FOR_IMAGE_SEGMENTATION_MAPPING else []
|
||||
)
|
||||
+ (MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items() if MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING else [])
|
||||
+ (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
|
||||
}
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||
@ -80,7 +84,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
||||
self.assertIsInstance(outputs, list)
|
||||
n = len(outputs)
|
||||
self.assertGreater(n, 1)
|
||||
if isinstance(image_segmenter.model, (MaskFormerForInstanceSegmentation)):
|
||||
# Instance segmentation (maskformer) have a slot for null class
|
||||
# and can output nothing even with a low threshold
|
||||
self.assertGreaterEqual(n, 0)
|
||||
else:
|
||||
self.assertGreaterEqual(n, 1)
|
||||
# XXX: PIL.Image implements __eq__ which bypasses ANY, so we inverse the comparison
|
||||
# to make it work
|
||||
self.assertEqual([{"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}] * n, outputs)
|
||||
@ -119,7 +128,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
]
|
||||
outputs = image_segmenter(batch, threshold=0.0, batch_size=batch_size)
|
||||
self.assertEqual(len(batch), len(outputs))
|
||||
self.assertEqual({"score": ANY(float, type(None)), "label": ANY(str), "mask": ANY(Image.Image)}, outputs[0][0])
|
||||
self.assertEqual(len(outputs[0]), n)
|
||||
self.assertEqual(
|
||||
[
|
||||
@ -313,18 +321,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
@require_torch
|
||||
@slow
|
||||
def test_maskformer(self):
|
||||
threshold = 0.999
|
||||
threshold = 0.8
|
||||
model_id = "facebook/maskformer-swin-base-ade"
|
||||
|
||||
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
|
||||
|
||||
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)
|
||||
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_id)
|
||||
model = AutoModelForInstanceSegmentation.from_pretrained(model_id)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=feature_extractor)
|
||||
|
||||
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
outputs = image_segmenter(image[0]["file"], threshold=threshold)
|
||||
file = image[0]["file"]
|
||||
outputs = image_segmenter(file, threshold=threshold)
|
||||
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
|
Loading…
Reference in New Issue
Block a user