mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Moving to cleaner tokenizer version or oneformer
. (#21292)
Moving to cleaner tokenizer version.
This commit is contained in:
parent
255257f3ea
commit
8788fd0ceb
@ -518,8 +518,8 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
reduce_labels=reduce_labels,
|
||||
)
|
||||
|
||||
def __call__(self, images, task_inputs, segmentation_maps=None, **kwargs) -> BatchFeature:
|
||||
return self.preprocess(images, task_inputs, segmentation_maps=segmentation_maps, **kwargs)
|
||||
def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature:
|
||||
return self.preprocess(images, task_inputs=task_inputs, segmentation_maps=segmentation_maps, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
@ -604,7 +604,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
task_inputs: List[str],
|
||||
task_inputs: Optional[List[str]] = None,
|
||||
segmentation_maps: Optional[ImageInput] = None,
|
||||
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
@ -639,6 +639,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
do_reduce_labels = kwargs.pop("reduce_labels")
|
||||
|
||||
if task_inputs is None:
|
||||
# Default value
|
||||
task_inputs = ["panoptic"]
|
||||
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
|
||||
@ -973,8 +977,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
classes, masks, texts = self.get_semantic_annotations(label, num_class_obj)
|
||||
elif task == "instance":
|
||||
classes, masks, texts = self.get_instance_annotations(label, num_class_obj)
|
||||
if task == "panoptic":
|
||||
elif task == "panoptic":
|
||||
classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj)
|
||||
else:
|
||||
raise ValueError(f"{task} was not expected, expected `semantic`, `instance` or `panoptic`")
|
||||
|
||||
# we cannot batch them since they don't share a common class size
|
||||
masks = [mask[None, ...] for mask in masks]
|
||||
@ -990,6 +996,9 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
encoded_inputs["class_labels"] = class_labels
|
||||
encoded_inputs["text_inputs"] = text_inputs
|
||||
|
||||
# This needs to be tokenized before sending to the model.
|
||||
encoded_inputs["task_inputs"] = [f"the task is {task_input}" for task_input in task_inputs]
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation
|
||||
|
@ -331,7 +331,7 @@ SUPPORTED_TASKS = {
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
|
||||
"type": "image",
|
||||
"type": "multimodal",
|
||||
},
|
||||
"image-to-text": {
|
||||
"impl": ImageToTextPipeline,
|
||||
|
@ -87,9 +87,11 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocessor_kwargs = {}
|
||||
postprocess_kwargs = {}
|
||||
if "subtask" in kwargs:
|
||||
postprocess_kwargs["subtask"] = kwargs["subtask"]
|
||||
preprocessor_kwargs["subtask"] = kwargs["subtask"]
|
||||
if "threshold" in kwargs:
|
||||
postprocess_kwargs["threshold"] = kwargs["threshold"]
|
||||
if "mask_threshold" in kwargs:
|
||||
@ -97,7 +99,7 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
if "overlap_mask_area_threshold" in kwargs:
|
||||
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
|
||||
|
||||
return {}, {}, postprocess_kwargs
|
||||
return preprocessor_kwargs, {}, postprocess_kwargs
|
||||
|
||||
def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
|
||||
"""
|
||||
@ -140,10 +142,23 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
"""
|
||||
return super().__call__(images, **kwargs)
|
||||
|
||||
def preprocess(self, image):
|
||||
def preprocess(self, image, subtask=None):
|
||||
image = load_image(image)
|
||||
target_size = [(image.height, image.width)]
|
||||
inputs = self.image_processor(images=[image], return_tensors="pt")
|
||||
if self.model.config.__class__.__name__ == "OneFormerConfig":
|
||||
if subtask is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"task_inputs": [subtask]}
|
||||
inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
|
||||
inputs["task_inputs"] = self.tokenizer(
|
||||
inputs["task_inputs"],
|
||||
padding="max_length",
|
||||
max_length=self.model.config.task_seq_len,
|
||||
return_tensors=self.framework,
|
||||
)["input_ids"]
|
||||
else:
|
||||
inputs = self.image_processor(images=[image], return_tensors="pt")
|
||||
inputs["target_size"] = target_size
|
||||
return inputs
|
||||
|
||||
|
@ -609,3 +609,105 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_oneformer(self):
|
||||
image_segmenter = pipeline(model="shi-labs/oneformer_ade20k_swin_tiny")
|
||||
|
||||
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
file = image[0]["file"]
|
||||
outputs = image_segmenter(file, threshold=0.99)
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{
|
||||
"score": 0.9981,
|
||||
"label": "grass",
|
||||
"mask": {"hash": "3a92904d4c", "white_pixels": 118131, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": 0.9992,
|
||||
"label": "sky",
|
||||
"mask": {"hash": "fa2300cc9a", "white_pixels": 231565, "shape": (512, 683)},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Different task
|
||||
outputs = image_segmenter(file, threshold=0.99, subtask="instance")
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{
|
||||
"score": 0.9991,
|
||||
"label": "sky",
|
||||
"mask": {"hash": "8b1ffad016", "white_pixels": 230566, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": 0.9981,
|
||||
"label": "grass",
|
||||
"mask": {"hash": "9bbdf83d3d", "white_pixels": 119130, "shape": (512, 683)},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Different task
|
||||
outputs = image_segmenter(file, subtask="semantic")
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{
|
||||
"score": None,
|
||||
"label": "wall",
|
||||
"mask": {"hash": "897fb20b7f", "white_pixels": 14506, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "building",
|
||||
"mask": {"hash": "f2a68c63e4", "white_pixels": 125019, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "sky",
|
||||
"mask": {"hash": "e0ca3a548e", "white_pixels": 135330, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "tree",
|
||||
"mask": {"hash": "7c9544bcac", "white_pixels": 16263, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "road, route",
|
||||
"mask": {"hash": "2c7704e491", "white_pixels": 2143, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "grass",
|
||||
"mask": {"hash": "bf6c2867e0", "white_pixels": 53040, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "plant",
|
||||
"mask": {"hash": "93c4b7199e", "white_pixels": 3335, "shape": (512, 683)},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "house",
|
||||
"mask": {"hash": "93ec419ad5", "white_pixels": 60, "shape": (512, 683)},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user