diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index acf5c65403d..561f94e4e73 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -936,7 +936,6 @@ class AriaProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template", "size_conversion"]
     image_processor_class = "AriaImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/aria/processing_aria.py b/src/transformers/models/aria/processing_aria.py
index 6fe5ea3d599..7ecf3af670c 100644
--- a/src/transformers/models/aria/processing_aria.py
+++ b/src/transformers/models/aria/processing_aria.py
@@ -60,7 +60,6 @@ class AriaProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template", "size_conversion"]
     image_processor_class = "AriaImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/aya_vision/processing_aya_vision.py b/src/transformers/models/aya_vision/processing_aya_vision.py
index 9b68d4fcf51..be3f04a1819 100644
--- a/src/transformers/models/aya_vision/processing_aya_vision.py
+++ b/src/transformers/models/aya_vision/processing_aya_vision.py
@@ -18,17 +18,8 @@ from typing import List, Optional, Union
 import numpy as np
 
 from ...image_processing_utils import BatchFeature
-from ...image_utils import (
-    ImageInput,
-    make_flat_list_of_images,
-)
-from ...processing_utils import (
-    ImagesKwargs,
-    MultiModalData,
-    ProcessingKwargs,
-    ProcessorMixin,
-    Unpack,
-)
+from ...image_utils import ImageInput, make_flat_list_of_images
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
 
@@ -87,19 +78,6 @@ class AyaVisionProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "image_token",
-        "patch_size",
-        "img_size",
-        "downsample_factor",
-        "start_of_img_token",
-        "end_of_img_token",
-        "img_patch_token",
-        "img_line_break_token",
-        "tile_token",
-        "tile_global_token",
-    ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py
index c65ff6b66fd..5970e5edbb1 100644
--- a/src/transformers/models/blip/processing_blip.py
+++ b/src/transformers/models/blip/processing_blip.py
@@ -55,7 +55,6 @@ class BlipProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = []
     image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
     tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
 
diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py
index 36b663dccb7..d94525f6b6f 100644
--- a/src/transformers/models/blip_2/processing_blip_2.py
+++ b/src/transformers/models/blip_2/processing_blip_2.py
@@ -21,12 +21,7 @@ from typing import List, Optional, Union
 from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
-from ...tokenization_utils_base import (
-    AddedToken,
-    BatchEncoding,
-    PreTokenizedInput,
-    TextInput,
-)
+from ...tokenization_utils_base import AddedToken, BatchEncoding, PreTokenizedInput, TextInput
 from ...utils import logging
 
@@ -67,7 +62,6 @@ class Blip2Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["num_query_tokens"]
     image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py
index 8f55eeb26a4..5a364cdc34d 100644
--- a/src/transformers/models/chameleon/processing_chameleon.py
+++ b/src/transformers/models/chameleon/processing_chameleon.py
@@ -72,7 +72,6 @@ class ChameleonProcessor(ProcessorMixin):
 
     attributes = ["image_processor", "tokenizer"]
     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
-    valid_kwargs = ["image_seq_length", "image_token"]
     image_processor_class = "ChameleonImageProcessor"
 
     def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py
index 03b76bf9595..f34681c1d4f 100644
--- a/src/transformers/models/colpali/processing_colpali.py
+++ b/src/transformers/models/colpali/processing_colpali.py
@@ -90,7 +90,6 @@ class ColPaliProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
     image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
     tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
 
diff --git a/src/transformers/models/csm/processing_csm.py b/src/transformers/models/csm/processing_csm.py
index ca516d82640..955f73cb363 100644
--- a/src/transformers/models/csm/processing_csm.py
+++ b/src/transformers/models/csm/processing_csm.py
@@ -31,10 +31,7 @@ if is_soundfile_available():
 from ...audio_utils import AudioInput, make_list_of_audio
 from ...feature_extraction_utils import BatchFeature
 from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
-from ...tokenization_utils_base import (
-    PreTokenizedInput,
-    TextInput,
-)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
 
 
 class CsmAudioKwargs(AudioKwargs, total=False):
@@ -99,7 +96,6 @@ class CsmProcessor(ProcessorMixin):
     """
 
     attributes = ["feature_extractor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
     feature_extractor_class = "EncodecFeatureExtractor"
     tokenizer_class = "PreTrainedTokenizerFast"
 
diff --git a/src/transformers/models/emu3/processing_emu3.py b/src/transformers/models/emu3/processing_emu3.py
index dd1928ac8a8..61b40217723 100644
--- a/src/transformers/models/emu3/processing_emu3.py
+++ b/src/transformers/models/emu3/processing_emu3.py
@@ -71,7 +71,6 @@ class Emu3Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
     tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
     image_processor_class = "Emu3ImageProcessor"
 
diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py
index 1b7fbab0e30..4852f3aaf9e 100644
--- a/src/transformers/models/fuyu/processing_fuyu.py
+++ b/src/transformers/models/fuyu/processing_fuyu.py
@@ -350,7 +350,6 @@ class FuyuProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = []
     image_processor_class = "FuyuImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/gemma3/processing_gemma3.py b/src/transformers/models/gemma3/processing_gemma3.py
index c8ac5e97146..ab6f03290a7 100644
--- a/src/transformers/models/gemma3/processing_gemma3.py
+++ b/src/transformers/models/gemma3/processing_gemma3.py
@@ -51,7 +51,6 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
 
 class Gemma3Processor(ProcessorMixin):
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template", "image_seq_length"]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/got_ocr2/processing_got_ocr2.py b/src/transformers/models/got_ocr2/processing_got_ocr2.py
index 5e40d14dee8..b712245a64c 100644
--- a/src/transformers/models/got_ocr2/processing_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/processing_got_ocr2.py
@@ -95,7 +95,6 @@ class GotOcr2Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "PreTrainedTokenizerFast"
 
diff --git a/src/transformers/models/granite_speech/processing_granite_speech.py b/src/transformers/models/granite_speech/processing_granite_speech.py
index ec36eb49703..9032601a6b2 100644
--- a/src/transformers/models/granite_speech/processing_granite_speech.py
+++ b/src/transformers/models/granite_speech/processing_granite_speech.py
@@ -31,8 +31,6 @@ logger = logging.get_logger(__name__)
 
 class GraniteSpeechProcessor(ProcessorMixin):
     attributes = ["audio_processor", "tokenizer"]
-    valid_kwargs = ["audio_token"]
-
     audio_processor_class = "GraniteSpeechFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py
index 37876080dfc..e226e15da19 100644
--- a/src/transformers/models/idefics/processing_idefics.py
+++ b/src/transformers/models/idefics/processing_idefics.py
@@ -211,7 +211,6 @@ class IdeficsProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["image_size", "add_end_of_utterance_token"]
     image_processor_class = "IdeficsImageProcessor"
     tokenizer_class = "LlamaTokenizerFast"
 
diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py
index ab144f3f9de..5be15d8cd8b 100644
--- a/src/transformers/models/idefics2/processing_idefics2.py
+++ b/src/transformers/models/idefics2/processing_idefics2.py
@@ -85,7 +85,6 @@ class Idefics2Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["image_seq_len", "chat_template"]
     image_processor_class = "Idefics2ImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py
index 5d058ec8245..5f4450df8b4 100644
--- a/src/transformers/models/idefics3/processing_idefics3.py
+++ b/src/transformers/models/idefics3/processing_idefics3.py
@@ -133,7 +133,6 @@ class Idefics3Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["image_seq_len", "chat_template"]
     image_processor_class = "Idefics3ImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py
index 408dfbd0756..d3df6f4ef90 100644
--- a/src/transformers/models/instructblip/processing_instructblip.py
+++ b/src/transformers/models/instructblip/processing_instructblip.py
@@ -22,12 +22,7 @@ from typing import List, Union
 from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
-from ...tokenization_utils_base import (
-    AddedToken,
-    BatchEncoding,
-    PreTokenizedInput,
-    TextInput,
-)
+from ...tokenization_utils_base import AddedToken, BatchEncoding, PreTokenizedInput, TextInput
 from ...utils import logging
 from ..auto import AutoTokenizer
 
@@ -72,7 +67,6 @@ class InstructBlipProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer", "qformer_tokenizer"]
-    valid_kwargs = ["num_query_tokens"]
     image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
     tokenizer_class = "AutoTokenizer"
     qformer_tokenizer_class = "AutoTokenizer"
diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py
index 8c59606e4b6..fad69b72e2f 100644
--- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py
@@ -57,7 +57,6 @@ class InstructBlipVideoProcessor(ProcessorMixin):
     """
 
     attributes = ["video_processor", "tokenizer", "qformer_tokenizer"]
-    valid_kwargs = ["num_query_tokens"]
     video_processor_class = "AutoVideoProcessor"
     tokenizer_class = "AutoTokenizer"
     qformer_tokenizer_class = "AutoTokenizer"
diff --git a/src/transformers/models/internvl/processing_internvl.py b/src/transformers/models/internvl/processing_internvl.py
index addb45a688a..c9a8c2028d5 100644
--- a/src/transformers/models/internvl/processing_internvl.py
+++ b/src/transformers/models/internvl/processing_internvl.py
@@ -18,18 +18,8 @@ from typing import List, Optional, Union
 import numpy as np
 
 from ...image_processing_utils import BatchFeature
-from ...image_utils import (
-    ImageInput,
-    concatenate_list,
-    make_flat_list_of_images,
-)
-from ...processing_utils import (
-    ImagesKwargs,
-    MultiModalData,
-    ProcessingKwargs,
-    ProcessorMixin,
-    Unpack,
-)
+from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
+from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
 
@@ -74,10 +64,6 @@ class InternVLProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer", "video_processor"]
-    valid_kwargs = [
-        "chat_template",
-        "image_seq_length",
-    ]
     image_processor_class = "AutoImageProcessor"
     video_processor_class = "AutoVideoProcessor"
     tokenizer_class = "AutoTokenizer"
diff --git a/src/transformers/models/janus/processing_janus.py b/src/transformers/models/janus/processing_janus.py
index 4132ca8f43d..d5f626a24c7 100644
--- a/src/transformers/models/janus/processing_janus.py
+++ b/src/transformers/models/janus/processing_janus.py
@@ -21,10 +21,7 @@ from typing import List, Union
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
-from ...tokenization_utils_base import (
-    PreTokenizedInput,
-    TextInput,
-)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import logging
 
@@ -68,7 +65,6 @@ class JanusProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template", "use_default_system_prompt"]
     image_processor_class = "JanusImageProcessor"
     tokenizer_class = "LlamaTokenizerFast"
 
diff --git a/src/transformers/models/kosmos2/processing_kosmos2.py b/src/transformers/models/kosmos2/processing_kosmos2.py
index 73a3f66f9b5..3a1c9253824 100644
--- a/src/transformers/models/kosmos2/processing_kosmos2.py
+++ b/src/transformers/models/kosmos2/processing_kosmos2.py
@@ -84,7 +84,6 @@ class Kosmos2Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["num_patch_index_tokens"]
     image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/llama4/processing_llama4.py b/src/transformers/models/llama4/processing_llama4.py
index 7ca562571cb..a020826aade 100644
--- a/src/transformers/models/llama4/processing_llama4.py
+++ b/src/transformers/models/llama4/processing_llama4.py
@@ -16,19 +16,11 @@
 from typing import List, Optional, Union
 
-from transformers.processing_utils import (
-    ImagesKwargs,
-    ProcessingKwargs,
-    ProcessorMixin,
-    Unpack,
-)
+from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 
 from ...image_processing_utils import BatchFeature
-from ...image_utils import (
-    ImageInput,
-    make_flat_list_of_images,
-)
+from ...image_utils import ImageInput, make_flat_list_of_images
 
 
 class Llama4ImagesKwargs(ImagesKwargs, total=False):
@@ -83,19 +75,6 @@ class Llama4Processor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "image_token",
-        "patch_size",
-        "img_size",
-        "downsample_factor",
-        "start_of_img_token",
-        "end_of_img_token",
-        "img_patch_token",
-        "img_line_break_token",
-        "tile_token",
-        "tile_global_token",
-    ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index d0ad3d0af17..b345df4d23b 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -70,13 +70,6 @@ class LlavaProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "patch_size",
-        "vision_feature_select_strategy",
-        "image_token",
-        "num_additional_image_tokens",
-    ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py
index 7a3eb43b7be..5a25f4072c2 100644
--- a/src/transformers/models/llava_next/processing_llava_next.py
+++ b/src/transformers/models/llava_next/processing_llava_next.py
@@ -76,13 +76,6 @@ class LlavaNextProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "patch_size",
-        "vision_feature_select_strategy",
-        "image_token",
-        "num_additional_image_tokens",
-    ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py
index a3619c616be..b4cd7f7c221 100644
--- a/src/transformers/models/llava_next_video/processing_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -78,14 +78,6 @@ class LlavaNextVideoProcessor(ProcessorMixin):
     # video and image processor share same args, but have different processing logic
     # only image processor config is saved in the hub
     attributes = ["video_processor", "image_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "patch_size",
-        "vision_feature_select_strategy",
-        "image_token",
-        "video_token",
-        "num_additional_image_tokens",
-    ]
     image_processor_class = ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")
     video_processor_class = "AutoVideoProcessor"
     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py
index 224b6e37cca..00cf4b579eb 100644
--- a/src/transformers/models/llava_onevision/processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py
@@ -75,14 +75,6 @@ class LlavaOnevisionProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer", "video_processor"]
-    valid_kwargs = [
-        "chat_template",
-        "num_image_tokens",
-        "vision_feature_select_strategy",
-        "image_token",
-        "video_token",
-        "vision_aspect_ratio",
-    ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
     video_processor_class = "AutoVideoProcessor"
diff --git a/src/transformers/models/mllama/processing_mllama.py b/src/transformers/models/mllama/processing_mllama.py
index 465d7c49544..ffe76709482 100644
--- a/src/transformers/models/mllama/processing_mllama.py
+++ b/src/transformers/models/mllama/processing_mllama.py
@@ -22,10 +22,7 @@ import numpy as np
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
 from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
-from ...tokenization_utils_base import (
-    PreTokenizedInput,
-    TextInput,
-)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
 
 
 class MllamaImagesKwargs(ImagesKwargs, total=False):
@@ -208,7 +205,6 @@ class MllamaProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
     image_processor_class = "MllamaImageProcessor"
     tokenizer_class = "PreTrainedTokenizerFast"
 
diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py
index a630c4720ed..1440ea7b66f 100644
--- a/src/transformers/models/paligemma/processing_paligemma.py
+++ b/src/transformers/models/paligemma/processing_paligemma.py
@@ -31,11 +31,7 @@ from ...processing_utils import (
     Unpack,
     _validate_images_text_input_order,
 )
-from ...tokenization_utils_base import (
-    AddedToken,
-    PreTokenizedInput,
-    TextInput,
-)
+from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
 from ...utils import logging
 
@@ -120,7 +116,6 @@ class PaliGemmaProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
     image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
     tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
 
diff --git a/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py
index a0d5a75a59c..ebccdc04608 100644
--- a/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py
@@ -62,7 +62,6 @@ class Phi4MultimodalProcessor(ProcessorMixin):
     tokenizer_class = "GPT2TokenizerFast"
     image_processor_class = "Phi4MultimodalImageProcessorFast"
     audio_processor_class = "Phi4MultimodalFeatureExtractor"
-    valid_kwargs = ["chat_template"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py
index aa5681f28b3..8a15fa8e1e5 100644
--- a/src/transformers/models/pixtral/processing_pixtral.py
+++ b/src/transformers/models/pixtral/processing_pixtral.py
@@ -90,14 +90,6 @@ class PixtralProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "patch_size",
-        "spatial_merge_size",
-        "image_token",
-        "image_break_token",
-        "image_end_token",
-    ]
     image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
index 28b4f38e485..ea449184705 100644
--- a/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
@@ -97,7 +97,6 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
     video_processor_class = "Qwen2VLVideoProcessor"
     feature_extractor_class = "WhisperFeatureExtractor"
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
-    valid_kwargs = ["chat_template"]
 
     def __init__(
         self, image_processor=None, video_processor=None, feature_extractor=None, tokenizer=None, chat_template=None
diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
index 8b05b725bf9..f835390a079 100644
--- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
@@ -75,7 +75,6 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer", "video_processor"]
-    valid_kwargs = ["chat_template"]
     image_processor_class = "AutoImageProcessor"
     video_processor_class = "AutoVideoProcessor"
 
diff --git a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
index 1d783ac26c3..1eac9e8b7c3 100644
--- a/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -60,7 +60,6 @@ class Qwen2AudioProcessor(ProcessorMixin):
     """
 
     attributes = ["feature_extractor", "tokenizer"]
-    valid_kwargs = ["chat_template", "audio_token", "audio_bos_token", "audio_eos_token"]
     feature_extractor_class = "WhisperFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
 
diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py
index dd4cafa132e..6cd056aa1d5 100644
--- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py
@@ -71,7 +71,6 @@ class Qwen2VLProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer", "video_processor"]
-    valid_kwargs = ["chat_template"]
     image_processor_class = "AutoImageProcessor"
     video_processor_class = "AutoVideoProcessor"
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
diff --git a/src/transformers/models/smolvlm/processing_smolvlm.py b/src/transformers/models/smolvlm/processing_smolvlm.py
index b858cc68a4e..a440a0f29b1 100644
--- a/src/transformers/models/smolvlm/processing_smolvlm.py
+++ b/src/transformers/models/smolvlm/processing_smolvlm.py
@@ -139,7 +139,6 @@ class SmolVLMProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer", "video_processor"]
-    valid_kwargs = ["image_seq_len", "chat_template"]
     image_processor_class = "SmolVLMImageProcessor"
     video_processor_class = (
         "SmolVLMImageProcessor"  # TODO: raushan should be VideoProcessor when LANCZOS resizing is settled
diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py
index d83af3a9f65..7a6edb8cff2 100644
--- a/src/transformers/models/video_llava/processing_video_llava.py
+++ b/src/transformers/models/video_llava/processing_video_llava.py
@@ -61,14 +61,6 @@ class VideoLlavaProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "video_processor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "patch_size",
-        "vision_feature_select_strategy",
-        "image_token",
-        "video_token",
-        "num_additional_image_tokens",
-    ]
     image_processor_class = "VideoLlavaImageProcessor"
     video_processor_class = "AutoVideoProcessor"
     tokenizer_class = "AutoTokenizer"
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index add9b4e2ad2..8ee7ce5adbb 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -497,7 +497,6 @@ class ProcessorMixin(PushToHubMixin):
     feature_extractor_class = None
     tokenizer_class = None
     _auto_class = None
-    valid_kwargs: list[str] = []
 
     # args have to match the attributes class attribute
     def __init__(self, *args, **kwargs):
@@ -996,18 +995,27 @@ class ProcessorMixin(PushToHubMixin):
         if "auto_map" in processor_dict:
             del processor_dict["auto_map"]
 
-        unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
-        processor = cls(*args, **processor_dict)
+        # override processor_dict with given kwargs
+        processor_dict.update(kwargs)
 
-        # Update processor with kwargs if needed
-        for key in set(kwargs.keys()):
-            if hasattr(processor, key):
-                setattr(processor, key, kwargs.pop(key))
+        # check if there is an overlap between args and processor_dict
+        accepted_args_and_kwargs = cls.__init__.__code__.co_varnames[: cls.__init__.__code__.co_argcount][1:]
+
+        # validate both processor_dict and given kwargs
+        unused_kwargs, valid_kwargs = cls.validate_init_kwargs(
+            processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs
+        )
+
+        # remove args that are in processor_dict to avoid duplicate arguments
+        args_to_remove = [i for i, arg in enumerate(accepted_args_and_kwargs) if arg in processor_dict]
+        args = [arg for i, arg in enumerate(args) if i not in args_to_remove]
+
+        # instantiate processor with used (and valid) kwargs only
+        processor = cls(*args, **valid_kwargs)
 
-        kwargs.update(unused_kwargs)
         logger.info(f"Processor {processor}")
         if return_unused_kwargs:
-            return processor, kwargs
+            return processor, unused_kwargs
         else:
             return processor
 
@@ -1294,12 +1302,16 @@
     @staticmethod
     def validate_init_kwargs(processor_config, valid_kwargs):
-        kwargs_from_config = processor_config.keys()
-        unused_kwargs = {}
-        unused_keys = set(kwargs_from_config) - set(valid_kwargs)
-        if unused_keys:
-            unused_kwargs = {k: processor_config[k] for k in unused_keys}
-        return unused_kwargs
+        kwargs_from_config = set(processor_config.keys())
+        valid_kwargs_set = set(valid_kwargs)
+
+        unused_keys = kwargs_from_config - valid_kwargs_set
+        valid_keys = kwargs_from_config & valid_kwargs_set
+
+        unused_kwargs = {k: processor_config[k] for k in unused_keys} if unused_keys else {}
+        valid_kwargs = {k: processor_config[k] for k in valid_keys} if valid_keys else {}
+
+        return unused_kwargs, valid_kwargs
 
     def prepare_and_validate_optional_call_args(self, *args):
         """
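
For context on the behavioral change: with the per-processor `valid_kwargs` lists removed, `from_args_and_dict` now derives the accepted keys from each processor's `__init__` signature, and `validate_init_kwargs` returns two dicts (unused and valid) instead of one. The snippet below is a minimal standalone sketch of that flow, not the transformers API itself; `FakeProcessor`, `stale_key`, and the sample values are hypothetical.

```python
# Standalone sketch of the refactored validation flow (hypothetical names).


def validate_init_kwargs(processor_config, valid_kwargs):
    """Split a config dict into (unused, valid) against a set of accepted keys."""
    kwargs_from_config = set(processor_config.keys())
    valid_kwargs_set = set(valid_kwargs)

    unused = {k: processor_config[k] for k in kwargs_from_config - valid_kwargs_set}
    valid = {k: processor_config[k] for k in kwargs_from_config & valid_kwargs_set}
    return unused, valid


class FakeProcessor:
    """Hypothetical stand-in for a ProcessorMixin subclass."""

    def __init__(self, tokenizer=None, chat_template=None, image_seq_length=576):
        self.tokenizer = tokenizer
        self.chat_template = chat_template
        self.image_seq_length = image_seq_length


# Derive the accepted argument names from __init__, dropping `self`, the same
# way the refactored from_args_and_dict does.
init_code = FakeProcessor.__init__.__code__
accepted = init_code.co_varnames[: init_code.co_argcount][1:]
assert accepted == ("tokenizer", "chat_template", "image_seq_length")

# A saved processor config may carry keys the current __init__ no longer takes.
config = {"chat_template": "{{ messages }}", "image_seq_length": 729, "stale_key": 1}
unused, valid = validate_init_kwargs(processor_config=config, valid_kwargs=accepted)

processor = FakeProcessor(**valid)  # only accepted keys reach __init__
assert processor.image_seq_length == 729
assert unused == {"stale_key": 1}  # surfaced when return_unused_kwargs=True
```

One caveat worth noting: `co_varnames[: co_argcount]` covers positional parameters only (keyword-only arguments are counted by `co_kwonlyargcount`), so a processor that declared an accepted option as keyword-only would not have it picked up by this introspection.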