mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Fixing image-segmentation
tests. (#14223)
This commit is contained in:
parent
7396095af7
commit
323f28dce2
@ -126,13 +126,13 @@ class ImageSegmentationPipeline(Pipeline):
|
|||||||
|
|
||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs):
|
||||||
target_size = model_inputs.pop("target_size")
|
target_size = model_inputs.pop("target_size")
|
||||||
outputs = self.model(**model_inputs)
|
model_outputs = self.model(**model_inputs)
|
||||||
model_outputs = {"outputs": outputs, "target_size": target_size}
|
model_outputs["target_size"] = target_size
|
||||||
return model_outputs
|
return model_outputs
|
||||||
|
|
||||||
def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5):
|
def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5):
|
||||||
raw_annotations = self.feature_extractor.post_process_segmentation(
|
raw_annotations = self.feature_extractor.post_process_segmentation(
|
||||||
model_outputs["outputs"], model_outputs["target_size"], threshold=threshold, mask_threshold=0.5
|
model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5
|
||||||
)
|
)
|
||||||
raw_annotation = raw_annotations[0]
|
raw_annotation = raw_annotations[0]
|
||||||
|
|
||||||
|
@ -51,13 +51,18 @@ else:
|
|||||||
@require_timm
|
@require_timm
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@unittest.skip("Skip while fixing segmentation pipeline tests")
|
|
||||||
class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
|
||||||
|
|
||||||
@require_datasets
|
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
def run_pipeline_test(self, model, tokenizer, feature_extractor):
|
|
||||||
image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
|
image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
|
||||||
|
return image_segmenter, [
|
||||||
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
@require_datasets
|
||||||
|
def run_pipeline_test(self, image_segmenter, examples):
|
||||||
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
||||||
self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)
|
self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user