Processor accepts any kwargs (#31889)

* accept kwargs in processors

* return unused kwargs

* fix tests

* typo

* update the other way
This commit is contained in:
Raushan Turganbay 2024-07-11 13:20:30 +05:00 committed by GitHub
parent a695c18649
commit 14d3b3f0f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 65 additions and 22 deletions

View File

@ -39,10 +39,11 @@ class BlipProcessor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "BlipImageProcessor" image_processor_class = "BlipImageProcessor"
tokenizer_class = ("BertTokenizer", "BertTokenizerFast") tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
def __init__(self, image_processor, tokenizer): def __init__(self, image_processor, tokenizer, **kwargs):
tokenizer.return_token_type_ids = False tokenizer.return_token_type_ids = False
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor self.current_processor = self.image_processor

View File

@ -39,11 +39,12 @@ class Blip2Processor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "BlipImageProcessor" image_processor_class = "BlipImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
# Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__ # Copied from transformers.models.blip.processing_blip.BlipProcessor.__init__
def __init__(self, image_processor, tokenizer): def __init__(self, image_processor, tokenizer, **kwargs):
tokenizer.return_token_type_ids = False tokenizer.return_token_type_ids = False
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor self.current_processor = self.image_processor

View File

@ -322,10 +322,11 @@ class FuyuProcessor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "FuyuImageProcessor" image_processor_class = "FuyuImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer): def __init__(self, image_processor, tokenizer, **kwargs):
super().__init__(image_processor=image_processor, tokenizer=tokenizer) super().__init__(image_processor=image_processor, tokenizer=tokenizer)
self.image_processor = image_processor self.image_processor = image_processor
self.tokenizer = tokenizer self.tokenizer = tokenizer

View File

@ -173,6 +173,7 @@ class IdeficsProcessor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["image_size", "add_end_of_utterance_token"]
image_processor_class = "IdeficsImageProcessor" image_processor_class = "IdeficsImageProcessor"
tokenizer_class = "LlamaTokenizerFast" tokenizer_class = "LlamaTokenizerFast"

View File

@ -61,6 +61,7 @@ class Idefics2Processor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["image_seq_len", "chat_template"]
image_processor_class = "Idefics2ImageProcessor" image_processor_class = "Idefics2ImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"

View File

@ -40,15 +40,16 @@ class InstructBlipProcessor(ProcessorMixin):
An instance of [`BlipImageProcessor`]. The image processor is a required input. An instance of [`BlipImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`): tokenizer (`AutoTokenizer`):
An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
qformer_tokenizer (`AutoTokenizer`): qformer_tokenizer (`AutoTokenizer`, *optional*):
An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input. An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is an optional input.
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "BlipImageProcessor" image_processor_class = "BlipImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer, qformer_tokenizer): def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)
# add QFormer tokenizer # add QFormer tokenizer
@ -167,7 +168,11 @@ class InstructBlipProcessor(ProcessorMixin):
# overwrite to load the Q-Former tokenizer from a separate folder # overwrite to load the Q-Former tokenizer from a separate folder
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer") qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) processor.qformer_tokenizer = qformer_tokenizer
args.append(qformer_tokenizer) return processor
return cls(*args)

View File

@ -40,15 +40,16 @@ class InstructBlipVideoProcessor(ProcessorMixin):
An instance of [`InstructBlipVideoImageProcessor`]. The image processor is a required input. An instance of [`InstructBlipVideoImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`): tokenizer (`AutoTokenizer`):
An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
qformer_tokenizer (`AutoTokenizer`): qformer_tokenizer (`AutoTokenizer`, *optional*):
An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input. An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is an optional input.
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "InstructBlipVideoImageProcessor" image_processor_class = "InstructBlipVideoImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer, qformer_tokenizer): def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)
# add QFormer tokenizer # add QFormer tokenizer
@ -164,7 +165,11 @@ class InstructBlipVideoProcessor(ProcessorMixin):
# overwrite to load the Q-Former tokenizer from a separate folder # overwrite to load the Q-Former tokenizer from a separate folder
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer") qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) processor.qformer_tokenizer = qformer_tokenizer
args.append(qformer_tokenizer) return processor
return cls(*args)

View File

@ -54,10 +54,11 @@ class Kosmos2Processor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["num_patch_index_tokens"]
image_processor_class = "CLIPImageProcessor" image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast") tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024): def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, **kwargs):
tokenizer.return_token_type_ids = False tokenizer.return_token_type_ids = False
self.eod_token = "</doc>" self.eod_token = "</doc>"

View File

@ -42,10 +42,11 @@ class LlavaProcessor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor" image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None): def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, chat_template=chat_template) super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__( def __call__(

View File

@ -42,10 +42,11 @@ class LlavaNextProcessor(ProcessorMixin):
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor" image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None): def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer, chat_template=chat_template) super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__( def __call__(

View File

@ -53,11 +53,12 @@ class LlavaNextVideoProcessor(ProcessorMixin):
# video and image processor share same args, but have different processing logic # video and image processor share same args, but have different processing logic
# only image processor config is saved in the hub # only image processor config is saved in the hub
attributes = ["video_processor", "image_processor", "tokenizer"] attributes = ["video_processor", "image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "LlavaNextImageProcessor" image_processor_class = "LlavaNextImageProcessor"
video_processor_class = "LlavaNextVideoImageProcessor" video_processor_class = "LlavaNextVideoImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
def __init__(self, video_processor=None, image_processor=None, tokenizer=None, chat_template=None): def __init__(self, video_processor=None, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template) super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)
def __call__( def __call__(

View File

@ -85,9 +85,12 @@ class PaliGemmaProcessor(ProcessorMixin):
The image processor is a required input. The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*): tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input. The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "SiglipImageProcessor" image_processor_class = "SiglipImageProcessor"
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
@ -95,6 +98,8 @@ class PaliGemmaProcessor(ProcessorMixin):
self, self,
image_processor=None, image_processor=None,
tokenizer=None, tokenizer=None,
chat_template=None,
**kwargs,
): ):
if image_processor is None: if image_processor is None:
raise ValueError("You need to specify an `image_processor`.") raise ValueError("You need to specify an `image_processor`.")
@ -113,7 +118,7 @@ class PaliGemmaProcessor(ProcessorMixin):
tokenizer.add_bos_token = False tokenizer.add_bos_token = False
tokenizer.add_eos_token = False tokenizer.add_eos_token = False
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__( def __call__(
self, self,

View File

@ -37,14 +37,17 @@ class VideoLlavaProcessor(ProcessorMixin):
The image processor is a required input. The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*): tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input. The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
""" """
attributes = ["image_processor", "tokenizer"] attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "VideoLlavaImageProcessor" image_processor_class = "VideoLlavaImageProcessor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None): def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__( def __call__(
self, self,

View File

@ -320,6 +320,7 @@ class ProcessorMixin(PushToHubMixin):
feature_extractor_class = None feature_extractor_class = None
tokenizer_class = None tokenizer_class = None
_auto_class = None _auto_class = None
valid_kwargs: List[str] = []
# args have to match the attributes class attribute # args have to match the attributes class attribute
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -648,14 +649,15 @@ class ProcessorMixin(PushToHubMixin):
processor_dict = processor_dict.copy() processor_dict = processor_dict.copy()
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# Unlike image processors or feature extractors whose `__init__` accept `kwargs`, processor don't have `kwargs`. # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
# We have to pop up some unused (but specific) arguments to make it work. # If we don't pop, some specific kwargs will raise a warning
if "processor_class" in processor_dict: if "processor_class" in processor_dict:
del processor_dict["processor_class"] del processor_dict["processor_class"]
if "auto_map" in processor_dict: if "auto_map" in processor_dict:
del processor_dict["auto_map"] del processor_dict["auto_map"]
unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
processor = cls(*args, **processor_dict) processor = cls(*args, **processor_dict)
# Update processor with kwargs if needed # Update processor with kwargs if needed
@ -663,6 +665,7 @@ class ProcessorMixin(PushToHubMixin):
if hasattr(processor, key): if hasattr(processor, key):
setattr(processor, key, kwargs.pop(key)) setattr(processor, key, kwargs.pop(key))
kwargs.update(unused_kwargs)
logger.info(f"Processor {processor}") logger.info(f"Processor {processor}")
if return_unused_kwargs: if return_unused_kwargs:
return processor, kwargs return processor, kwargs
@ -887,6 +890,19 @@ class ProcessorMixin(PushToHubMixin):
first_attribute = getattr(self, self.attributes[0]) first_attribute = getattr(self, self.attributes[0])
return getattr(first_attribute, "model_input_names", None) return getattr(first_attribute, "model_input_names", None)
@staticmethod
def validate_init_kwargs(processor_config, valid_kwargs):
kwargs_from_config = processor_config.keys()
unused_kwargs = {}
unused_keys = set(kwargs_from_config) - set(valid_kwargs)
if unused_keys:
unused_key_str = ", ".join(unused_keys)
logger.warning(
f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
)
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs
def apply_chat_template( def apply_chat_template(
self, self,
conversation: Union[List[Dict[str, str]]], conversation: Union[List[Dict[str, str]]],