diff --git a/src/transformers/models/eomt/image_processing_eomt.py b/src/transformers/models/eomt/image_processing_eomt.py index 7c2e1bb7150..bf7a04cb3c0 100644 --- a/src/transformers/models/eomt/image_processing_eomt.py +++ b/src/transformers/models/eomt/image_processing_eomt.py @@ -858,6 +858,7 @@ class EomtImageProcessor(BaseImageProcessor): ): """Post-processes model outputs into final panoptic segmentation prediction.""" + # `mask_threshold` and `overlap_mask_area_threshold` args are unused and only present for Pipeline compatability. size = size if size is not None else self.size masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] diff --git a/src/transformers/models/eomt/image_processing_eomt_fast.py b/src/transformers/models/eomt/image_processing_eomt_fast.py index ca28521dcef..5dd2a130541 100644 --- a/src/transformers/models/eomt/image_processing_eomt_fast.py +++ b/src/transformers/models/eomt/image_processing_eomt_fast.py @@ -522,6 +522,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast): ): """Post-processes model outputs into Instance Segmentation Predictions.""" + # `mask_threshold` and `overlap_mask_area_threshold` args are unused and only present for Pipeline compatability. size = size if size is not None else self.size masks_queries_logits = outputs.masks_queries_logits diff --git a/tests/models/eomt/test_modeling_eomt.py b/tests/models/eomt/test_modeling_eomt.py index d9a090ac851..968d6188057 100644 --- a/tests/models/eomt/test_modeling_eomt.py +++ b/tests/models/eomt/test_modeling_eomt.py @@ -474,4 +474,9 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase): image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) pipe = pipeline(model=self.model_id, subtask="panoptic", device=torch_device) - _ = pipe(image) + output = pipe(image) + + EXPECTED_OUTPUT_LABELS = ["LABEL_15", "LABEL_15", "LABEL_57", "LABEL_65", "LABEL_65"] + + output_labels = [segment["label"] for segment in output["segments_info"]] + self.assertEqual(output_labels, EXPECTED_OUTPUT_LABELS)