add molbap's commit

This commit is contained in:
sangbumchoi 2024-07-24 01:47:47 +00:00
parent bb1f18bb3b
commit a476c6ee88
3 changed files with 225 additions and 69 deletions

View File

@@ -16,13 +16,14 @@
Processor class for Grounding DINO.
"""
import pathlib
import sys
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
if sys.version_info >= (3, 11):
@@ -31,12 +32,19 @@ else:
from typing_extensions import Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import TensorType, is_torch_available
from ...utils import ExplicitEnum, TensorType, is_torch_available
if is_torch_available():
import torch
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
class AnnotationFormat(ExplicitEnum):
COCO_DETECTION = "coco_detection"
COCO_PANOPTIC = "coco_panoptic"
def get_phrases_from_posmap(posmaps, input_ids):
"""Get token ids of phrases from posmaps and input_ids.
@@ -64,7 +72,16 @@ def get_phrases_from_posmap(posmaps, input_ids):
return token_ids
class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
do_convert_annotations: Optional[bool]
format: Optional[Union[str, AnnotationFormat]]
class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: GroundingDinoImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,

View File

@@ -150,6 +150,8 @@ class ImagesKwargs(TypedDict, total=False):
Standard deviation to use if normalizing the image.
do_pad (`bool`, *optional*):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to.
do_center_crop (`bool`, *optional*):
Whether to center crop the image.
data_format (`ChannelDimension` or `str`, *optional*):
@@ -169,6 +171,7 @@ class ImagesKwargs(TypedDict, total=False):
image_mean: Optional[Union[float, List[float]]]
image_std: Optional[Union[float, List[float]]]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
do_center_crop: Optional[bool]
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
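
The new pad_size entry lets callers request a fixed padding target instead of padding to the largest image in the batch. A rough sketch using SAM's image processor, which already exposes both do_pad and pad_size (sizes are illustrative):

import numpy as np
from transformers import SamImageProcessor

image_processor = SamImageProcessor(do_pad=True, pad_size={"height": 1024, "width": 1024})
image = (np.random.rand(300, 400, 3) * 255).astype("uint8")  # stand-in RGB image
batch = image_processor(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # padded to the requested 1024x1024 after resizing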
@@ -320,7 +323,6 @@ class ProcessorMixin(PushToHubMixin):
feature_extractor_class = None
tokenizer_class = None
_auto_class = None
valid_kwargs: List[str] = []
# args have to match the attributes class attribute
def __init__(self, *args, **kwargs):
@@ -649,15 +651,14 @@ class ProcessorMixin(PushToHubMixin):
processor_dict = processor_dict.copy()
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
# If we don't pop, some specific kwargs will raise a warning
# Unlike image processors or feature extractors, whose `__init__` accepts `kwargs`, processors don't have `kwargs`.
# We have to pop up some unused (but specific) arguments to make it work.
if "processor_class" in processor_dict:
del processor_dict["processor_class"]
if "auto_map" in processor_dict:
del processor_dict["auto_map"]
unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
processor = cls(*args, **processor_dict)
# Update processor with kwargs if needed
@@ -665,7 +666,6 @@ class ProcessorMixin(PushToHubMixin):
if hasattr(processor, key):
setattr(processor, key, kwargs.pop(key))
kwargs.update(unused_kwargs)
logger.info(f"Processor {processor}")
if return_unused_kwargs:
return processor, kwargs
@@ -743,38 +743,34 @@ class ProcessorMixin(PushToHubMixin):
if modality_key in tokenizer_init_kwargs:
default_kwargs[modality][modality_key] = tokenizer_init_kwargs[modality_key]
# now default kwargs are updated with the tokenizer's defaults.
# pass defaults to output dictionary
output_kwargs.update(default_kwargs)
# gather common kwargs and remove them from individual kwargs if present
common_kwargs = {
key: value
for key, value in kwargs.items()
if key not in ModelProcessorKwargs.__annotations__["text_kwargs"].__annotations__
and key not in ModelProcessorKwargs.__annotations__["images_kwargs"].__annotations__
and key not in ModelProcessorKwargs.__annotations__["audio_kwargs"].__annotations__
and key not in ModelProcessorKwargs.__annotations__["videos_kwargs"].__annotations__
}
# ensure common kwargs are propagated to all relevant modalities
for key, value in common_kwargs.items():
for modality in output_kwargs:
if modality != "common_kwargs":
output_kwargs[modality][key] = value
# remove common kwargs from the kwargs to process the rest
kwargs = {k: v for k, v in kwargs.items() if k not in common_kwargs}
# update modality kwargs with passed kwargs
non_modality_kwargs = set(kwargs) - set(output_kwargs)
for modality in output_kwargs:
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
# check if we received a structured kwarg dict or not to handle it correctly
if modality in kwargs:
kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
# check if this key was passed as a flat kwarg.
if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
raise ValueError(
f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg."
)
if modality in kwargs and modality_key in kwargs[modality]:
output_kwargs[modality][modality_key] = kwargs[modality][modality_key]
elif modality_key in kwargs:
kwarg_value = kwargs.pop(modality_key, "__empty__")
else:
kwarg_value = "__empty__"
if kwarg_value != "__empty__":
output_kwargs[modality][modality_key] = kwarg_value
# if something remains in kwargs, it belongs to common after flattening
if set(kwargs) & set(default_kwargs):
# here kwargs is dictionary-based since it shares keys with default set
[output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()]
else:
# here it's a flat dict
output_kwargs["common_kwargs"].update(kwargs)
# all modality-specific kwargs are updated with common kwargs
for modality in output_kwargs:
output_kwargs[modality].update(output_kwargs["common_kwargs"])
output_kwargs[modality][modality_key] = kwargs[modality_key]
return output_kwargs
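
In practice, the merged output lets callers pass either flat kwargs or per-modality dicts, and refuses a key passed both ways. A condensed sketch of the calling patterns this supports, where processor, text and images are placeholders for any processor with typed kwargs (the values mirror the tests further below):

# Flat kwargs: each key is dispatched to the modality whose typed dict declares it.
out = processor(text=text, images=images, return_tensors="pt",
                padding="max_length", max_length=76,
                crop_size={"height": 214, "width": 214})

# Structured kwargs: the same routing, spelled out per modality.
out = processor(
    text=text,
    images=images,
    text_kwargs={"padding": "max_length", "max_length": 76},
    images_kwargs={"crop_size": {"height": 214, "width": 214}},
    common_kwargs={"return_tensors": "pt"},
)

# Passing the same key both flat and inside its modality dict raises a ValueError.
processor(text=text, images=images, padding="max_length", text_kwargs={"padding": "longest"})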
@classmethod
@@ -890,19 +886,6 @@ class ProcessorMixin(PushToHubMixin):
first_attribute = getattr(self, self.attributes[0])
return getattr(first_attribute, "model_input_names", None)
@staticmethod
def validate_init_kwargs(processor_config, valid_kwargs):
kwargs_from_config = processor_config.keys()
unused_kwargs = {}
unused_keys = set(kwargs_from_config) - set(valid_kwargs)
if unused_keys:
unused_key_str = ", ".join(unused_keys)
logger.warning(
f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
)
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],

View File

@@ -16,6 +16,7 @@
import inspect
import json
import random
import tempfile
@@ -38,15 +39,31 @@ from transformers.testing_utils import (
from transformers.utils import is_vision_available
global_rng = random.Random()
if is_vision_available():
from PIL import Image
from transformers import CLIPImageProcessor
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
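
floats_list returns a plain nested Python list rather than a tensor; the audio tests below pass it as stand-in raw speech. For example:

raw_speech = floats_list((3, 1000))  # 3 "waveforms" of 1000 floats each
assert len(raw_speech) == 3 and len(raw_speech[0]) == 1000
assert all(0.0 <= v < 1.0 for v in raw_speech[0])  # values drawn from [0, scale) with scale=1.0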
@require_torch
@require_vision
@require_torch
class ProcessorTesterMixin:
processor_class = None
@@ -60,7 +77,10 @@ class ProcessorTesterMixin:
component_class_name = component_class_name[0]
component_class = processor_class_from_name(component_class_name)
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
if hasattr(self, "tmpdirname"):
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
elif hasattr(self, "model_id"):
component = component_class.from_pretrained(self.model_id, **kwargs) # noqa
return component
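
With this fallback, a test class can point the mixin either at components saved under tmpdirname or directly at a Hub checkpoint via model_id. A sketch of the two setups (class and checkpoint names are hypothetical):

import tempfile
import unittest

class MyModelProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = MyModelProcessor  # hypothetical processor under test

    def setUp(self):
        # Option 1: save the components locally and let get_component load them from disk.
        self.tmpdirname = tempfile.mkdtemp()
        # Option 2: skip the local save and load straight from the Hub instead:
        # self.model_id = "hf-internal-testing/tiny-random-SomeModel"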
@@ -126,13 +146,13 @@ class ProcessorTesterMixin:
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
@@ -141,15 +161,15 @@
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", crop_size=(234, 234))
image_processor = self.get_component("image_processor", crop_size=(234, 234), size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
@@ -160,13 +180,15 @@
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=112)
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
@require_torch
@@ -174,16 +196,17 @@
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", crop_size=(234, 234))
image_processor = self.get_component("image_processor", crop_size=(234, 234), size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, crop_size=[224, 224])
inputs = processor(text=input_str, images=image_input, crop_size=[224, 224], size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
@require_torch
@@ -193,7 +216,8 @@ class ProcessorTesterMixin:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
@@ -204,6 +228,7 @@ class ProcessorTesterMixin:
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)
@@ -218,7 +243,8 @@ class ProcessorTesterMixin:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
@@ -229,10 +255,10 @@ class ProcessorTesterMixin:
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 6)
@@ -244,7 +270,8 @@ class ProcessorTesterMixin:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
@@ -265,7 +292,8 @@ class ProcessorTesterMixin:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
@@ -275,7 +303,7 @@ class ProcessorTesterMixin:
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}, "size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
@@ -294,7 +322,8 @@ class ProcessorTesterMixin:
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
@@ -303,7 +332,7 @@ class ProcessorTesterMixin:
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}, "size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
@@ -312,6 +341,133 @@ class ProcessorTesterMixin:
self.assertEqual(len(inputs["input_ids"][0]), 76)
# text + audio kwargs testing
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117, padding="max_length")
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
raw_speech = floats_list((3, 1000))
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 117)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 117)
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
raw_speech = floats_list((3, 1000))
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 112)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 112)
@require_torch
def test_unstructured_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
raw_speech = floats_list((3, 1000))
inputs = processor(
text=input_str,
audio=raw_speech,
return_tensors="pt",
padding="max_length",
max_length=76,
)
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
@require_torch
def test_doubly_passed_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
raw_speech = floats_list((3, 1000))
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
audio=raw_speech,
audio_kwargs={"padding": "max_length"},
padding="max_length",
)
@require_torch
@require_vision
def test_structured_kwargs_audio_nested(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
raw_speech = floats_list((3, 1000))
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"audio_kwargs": {"padding": "max_length", "max_length": 66},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
class MyProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]