mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix image segmentation tool bug (#23897)
* Image segmentation tool bug * Remove resizing in the tests
This commit is contained in:
parent
6cd34d451c
commit
e6122c3f40
@ -44,7 +44,6 @@ class ImageSegmentationTool(PipelineTool):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def encode(self, image: "Image", label: str):
|
||||
self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
|
||||
return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
|
||||
|
||||
def forward(self, inputs):
|
||||
|
@ -33,21 +33,21 @@ class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
self.remote_tool = load_tool("image-segmentation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, label="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image, label="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
Loading…
Reference in New Issue
Block a user