diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index 2b24fd2aeed..923cbd72362 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -126,13 +126,13 @@ class ImageSegmentationPipeline(Pipeline): def _forward(self, model_inputs): target_size = model_inputs.pop("target_size") - outputs = self.model(**model_inputs) - model_outputs = {"outputs": outputs, "target_size": target_size} + model_outputs = self.model(**model_inputs) + model_outputs["target_size"] = target_size return model_outputs def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5): raw_annotations = self.feature_extractor.post_process_segmentation( - model_outputs["outputs"], model_outputs["target_size"], threshold=threshold, mask_threshold=0.5 + model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5 ) raw_annotation = raw_annotations[0] diff --git a/tests/test_pipelines_image_segmentation.py b/tests/test_pipelines_image_segmentation.py index dc07e44a67a..ad4d456ba63 100644 --- a/tests/test_pipelines_image_segmentation.py +++ b/tests/test_pipelines_image_segmentation.py @@ -51,13 +51,18 @@ else: @require_timm @require_torch @is_pipeline_test -@unittest.skip("Skip while fixing segmentation pipeline tests") class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING - @require_datasets - def run_pipeline_test(self, model, tokenizer, feature_extractor): + def get_test_pipeline(self, model, tokenizer, feature_extractor): image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor) + return image_segmenter, [ + "./tests/fixtures/tests_samples/COCO/000000039769.png", + "./tests/fixtures/tests_samples/COCO/000000039769.png", + ] + + @require_datasets + def run_pipeline_test(self, image_segmenter, examples): outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0) self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)