From 3bc44eaaeee01b7f0d2d55c9991900b43cafe62d Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 17 Apr 2025 09:38:12 +0200 Subject: [PATCH] [qwen-vl] Standardize config (#37268) * update * fix tests * fixup * update * skip this one * fixup * fix --- docs/source/en/model_doc/qwen2_5_vl.md | 5 ++ docs/source/en/model_doc/qwen2_vl.md | 4 + .../models/auto/configuration_auto.py | 6 ++ src/transformers/models/auto/modeling_auto.py | 2 + .../qwen2_5_omni/modeling_qwen2_5_omni.py | 2 +- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 86 +++++++++++++++---- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 22 +++-- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 11 ++- .../models/qwen2_vl/configuration_qwen2_vl.py | 86 +++++++++++++++---- .../models/qwen2_vl/modeling_qwen2_vl.py | 23 ++--- .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 4 + .../models/qwen2_vl/test_modeling_qwen2_vl.py | 4 + utils/check_config_attributes.py | 2 + 13 files changed, 202 insertions(+), 55 deletions(-) diff --git a/docs/source/en/model_doc/qwen2_5_vl.md b/docs/source/en/model_doc/qwen2_5_vl.md index 2d38fe82e61..2fb2eadc53e 100644 --- a/docs/source/en/model_doc/qwen2_5_vl.md +++ b/docs/source/en/model_doc/qwen2_5_vl.md @@ -232,10 +232,15 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained( [[autodoc]] Qwen2_5_VLConfig +## Qwen2_5_VLTextConfig + +[[autodoc]] Qwen2_5_VLTextConfig + ## Qwen2_5_VLProcessor [[autodoc]] Qwen2_5_VLProcessor + ## Qwen2_5_VLModel [[autodoc]] Qwen2_5_VLModel diff --git a/docs/source/en/model_doc/qwen2_vl.md b/docs/source/en/model_doc/qwen2_vl.md index 37c7ad31b31..3d1845b6015 100644 --- a/docs/source/en/model_doc/qwen2_vl.md +++ b/docs/source/en/model_doc/qwen2_vl.md @@ -278,6 +278,10 @@ model = Qwen2VLForConditionalGeneration.from_pretrained( [[autodoc]] Qwen2VLConfig +## Qwen2VLTextConfig + +[[autodoc]] Qwen2VLTextConfig + ## Qwen2VLImageProcessor [[autodoc]] Qwen2VLImageProcessor diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 73192931338..a2a69e923cb 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -258,10 +258,12 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("qwen2", "Qwen2Config"), ("qwen2_5_omni", "Qwen2_5OmniConfig"), ("qwen2_5_vl", "Qwen2_5_VLConfig"), + ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"), ("qwen2_audio", "Qwen2AudioConfig"), ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"), ("qwen2_moe", "Qwen2MoeConfig"), ("qwen2_vl", "Qwen2VLConfig"), + ("qwen2_vl_text", "Qwen2VLTextConfig"), ("qwen3", "Qwen3Config"), ("qwen3_moe", "Qwen3MoeConfig"), ("rag", "RagConfig"), @@ -625,10 +627,12 @@ MODEL_NAMES_MAPPING = OrderedDict( ("qwen2", "Qwen2"), ("qwen2_5_omni", "Qwen2_5Omni"), ("qwen2_5_vl", "Qwen2_5_VL"), + ("qwen2_5_vl_text", "Qwen2_5_VL"), ("qwen2_audio", "Qwen2Audio"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoE"), ("qwen2_vl", "Qwen2VL"), + ("qwen2_vl_text", "Qwen2VL"), ("qwen3", "Qwen3"), ("qwen3_moe", "Qwen3MoE"), ("rag", "RAG"), @@ -793,6 +797,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), ("granitevision", "llava_next"), + ("qwen2_5_vl_text", "qwen2_5_vl"), + ("qwen2_vl_text", "qwen2_vl"), ("sam_vision_model", "sam"), ("llama4_text", "llama4"), ("blip_2_qformer", "blip_2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1746d97fd0f..7982ab8d981 100644 
--- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -234,9 +234,11 @@ MODEL_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertModel"), ("qwen2", "Qwen2Model"), ("qwen2_5_vl", "Qwen2_5_VLModel"), + ("qwen2_5_vl_text", "Qwen2_5_VLModel"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), + ("qwen2_vl_text", "Qwen2VLModel"), ("qwen3", "Qwen3Model"), ("qwen3_moe", "Qwen3MoeModel"), ("recurrent_gemma", "RecurrentGemmaModel"), diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 718b719e58a..2898354399b 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1792,7 +1792,7 @@ QWEN2_5_OMNI_ATTENTION_CLASSES = { class Qwen2_5OmniDecoderLayer(nn.Module): - def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int): + def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 63ca1c23592..588ad214ebd 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -67,9 +67,9 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): self.initializer_range = initializer_range -class Qwen2_5_VLConfig(PretrainedConfig): +class Qwen2_5_VLTextConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). @@ -77,7 +77,6 @@ class Qwen2_5_VLConfig(PretrainedConfig): Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: vocab_size (`int`, *optional*, defaults to 152064): Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the @@ -120,8 +119,6 @@ class Qwen2_5_VLConfig(PretrainedConfig): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - vision_config (`Dict`, *optional*): - The config for the visual encoder initialization. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value @@ -161,20 +158,20 @@ class Qwen2_5_VLConfig(PretrainedConfig): Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE ```python - >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig >>> # Initializing a Qwen2_5_VL style configuration >>> configuration = Qwen2_5_VLConfig() >>> # Initializing a model from the Qwen2-VL-7B style configuration - >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + >>> model = Qwen2_5_VLTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" - model_type = "qwen2_5_vl" - sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + model_type = "qwen2_5_vl_text" + base_config_key = "text_config" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `Qwen2_5_VL` base_model_tp_plan = { @@ -211,15 +208,9 @@ class Qwen2_5_VLConfig(PretrainedConfig): sliding_window=4096, max_window_layers=80, attention_dropout=0.0, - vision_config=None, rope_scaling=None, **kwargs, ): - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -257,4 +248,67 @@ class Qwen2_5_VLConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -__all__ = ["Qwen2_5_VLConfig"] +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt. 
+ + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + # For BC use all kwargs to init `TextConfig` + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + super().__init__(**kwargs) + + +__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"] diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 388b5f90555..8155a0d280b 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -48,7 +48,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig if is_flash_attn_available(): @@ -390,7 +390,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` def _init_weights(self, module): - std = self.config.initializer_range + std = self.config.get_text_config().initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: @@ -566,7 +566,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): class Qwen2_5_VLRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen2_5_VLConfig, device=None): + def __init__(self, config: Qwen2_5_VLTextConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: @@ -680,7 +680,7 @@ class Qwen2_5_VLAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". 
""" - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -989,7 +989,7 @@ QWEN2_5_VL_ATTENTION_CLASSES = { class Qwen2_5_VLDecoderLayer(nn.Module): - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1077,7 +1077,9 @@ class Qwen2_5_VLDecoderLayer(nn.Module): Qwen2_5_VL_START_DOCSTRING, ) class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): - def __init__(self, config: Qwen2_5_VLConfig): + config_class = Qwen2_5_VLTextConfig + + def __init__(self, config: Qwen2_5_VLTextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1497,9 +1499,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def __init__(self, config): super().__init__(config) self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) - self.model = Qwen2_5_VLModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + text_config = config.get_text_config() + self.model = Qwen2_5_VLModel._from_config(text_config) + self.vocab_size = text_config.vocab_size + self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) self.rope_deltas = None # cache rope_deltas here # Initialize weights and apply final processing diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index b47db208627..f34c48bb542 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -28,7 +28,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig from transformers.models.qwen2_vl.modeling_qwen2_vl import ( PatchEmbed, PatchMerger, @@ -110,9 +110,13 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): self.initializer_range = initializer_range +class Qwen2_5_VLTextConfig(Qwen2VLTextConfig): + model_type = "qwen2_5_vl_text" + + class Qwen2_5_VLConfig(Qwen2VLConfig): model_type = "qwen2_5_vl" - sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig} class Qwen2_5_VLMLP(nn.Module): @@ -227,7 +231,7 @@ class Qwen2_5_VLVisionBlock(nn.Module): class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): def _init_weights(self, module): - std = self.config.initializer_range + std = self.config.get_text_config().initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: @@ -971,6 +975,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor): __all__ = [ "Qwen2_5_VLConfig", + "Qwen2_5_VLTextConfig", "Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index b03dbc8f0b6..ee2ed40e463 100644 --- 
a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -56,9 +56,9 @@ class Qwen2VLVisionConfig(PretrainedConfig): self.initializer_range = initializer_range -class Qwen2VLConfig(PretrainedConfig): +class Qwen2VLTextConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a + This is the configuration class to store the configuration of a [`Qwen2VLTextModel`]. It is used to instantiate a Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). @@ -66,7 +66,6 @@ class Qwen2VLConfig(PretrainedConfig): Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: vocab_size (`int`, *optional*, defaults to 152064): Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the @@ -109,8 +108,6 @@ class Qwen2VLConfig(PretrainedConfig): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - vision_config (`Dict`, *optional*): - The config for the visual encoder initialization. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value @@ -150,20 +147,20 @@ class Qwen2VLConfig(PretrainedConfig): Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE ```python - >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig + >>> from transformers import Qwen2VLTextModel, Qwen2VLConfig >>> # Initializing a Qwen2VL style configuration >>> configuration = Qwen2VLConfig() >>> # Initializing a model from the Qwen2-VL-7B style configuration - >>> model = Qwen2VLForConditionalGeneration(configuration) + >>> model = Qwen2VLTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" - model_type = "qwen2_vl" - sub_configs = {"vision_config": Qwen2VLVisionConfig} + model_type = "qwen2_vl_text" + base_config_key = "text_config" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `Qwen2VL` base_model_tp_plan = { @@ -200,15 +197,9 @@ class Qwen2VLConfig(PretrainedConfig): sliding_window=4096, max_window_layers=80, attention_dropout=0.0, - vision_config=None, rope_scaling=None, **kwargs, ): - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -246,4 +237,67 @@ class Qwen2VLConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -__all__ = ["Qwen2VLConfig"] +class Qwen2VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2VLTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2VLVisionConfig`): + The config object or dictionary of the vision backbone. + image_token_id (`int`, *optional*, defaults to 151655): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 151656): + The video token index to encode the image prompt.
+ + ```python + >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig + + >>> # Initializing a Qwen2VL style configuration + >>> configuration = Qwen2VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_vl" + sub_configs = {"vision_config": Qwen2VLVisionConfig, "text_config": Qwen2VLTextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + # For BC use all kwargs to init `TextConfig` + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + super().__init__(**kwargs) + + +__all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"] diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 96cbb4d3a5c..ebc3740c357 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -44,7 +44,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig +from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig if is_flash_attn_available(): @@ -101,7 +101,7 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput): class Qwen2VLRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen2VLConfig, device=None): + def __init__(self, config: Qwen2VLTextConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: @@ -494,7 +494,7 @@ class Qwen2VLAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2VLTextConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -803,7 +803,7 @@ QWEN2_VL_ATTENTION_CLASSES = { class Qwen2VLDecoderLayer(nn.Module): - def __init__(self, config: Qwen2VLConfig, layer_idx: int): + def __init__(self, config: Qwen2VLTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -919,8 +919,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _supports_static_cache = False # TODO (joao): fix.
torch.compile failing probably due to `cache_positions` def _init_weights(self, module): - std = self.config.initializer_range - + std = self.config.get_text_config().initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: @@ -1029,7 +1028,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): QWEN2VL_START_DOCSTRING, ) class Qwen2VLModel(Qwen2VLPreTrainedModel): - def __init__(self, config: Qwen2VLConfig): + config_class = Qwen2VLTextConfig + + def __init__(self, config: Qwen2VLTextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1410,9 +1411,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config) - self.model = Qwen2VLModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + text_config = config.get_text_config() + self.model = Qwen2VLModel._from_config(text_config) + self.vocab_size = text_config.vocab_size + self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False) self.rope_deltas = None # cache rope_deltas here # Initialize weights and apply final processing diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 82e8229f796..a0579ce2029 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -312,6 +312,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test def test_prompt_lookup_decoding_matches_greedy_search(self): super().test_prompt_lookup_decoding_matches_greedy_search() + @unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`") + def test_save_load_fast_init_from_base(self): + pass + @require_torch class Qwen2_5_VLIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 8c9086f162e..31cf74e7c97 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -316,6 +316,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`") + def test_save_load_fast_init_from_base(self): + pass + @require_torch class Qwen2VLIntegrationTest(unittest.TestCase): diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 7b9744f4a90..a0b31c8cc02 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -345,6 +345,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s # common and important attributes, even if they do not always appear in the modeling files attributes_to_allow = [ + "initializer_range", "bos_index", "eos_index", "pad_index", @@ -355,6 +356,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "image_seq_length", "video_seq_length", "image_size", + "text_config", # may appear as `get_text_config()` "use_cache", "out_features", "out_indices",
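With this change, `Qwen2VLConfig` and `Qwen2_5_VLConfig` become thin composite configs that hold a `text_config` and a `vision_config`, while old-style flat kwargs remain accepted: when no `text_config` is passed, all kwargs are forwarded into the nested text config (the "For BC use all kwargs to init `TextConfig`" branch above). A minimal sketch of both construction paths, assuming a transformers build that includes this patch; the hyperparameter values are illustrative only:

```python
from transformers import Qwen2_5_VLConfig

# New style: pass the text backbone settings as a nested dict; the vision
# sub-config falls back to its defaults when not given.
config = Qwen2_5_VLConfig(
    text_config={"hidden_size": 1024, "num_hidden_layers": 4, "num_attention_heads": 16},
)
print(type(config.text_config).__name__)    # Qwen2_5_VLTextConfig
print(type(config.vision_config).__name__)  # Qwen2_5_VLVisionConfig

# Old style (kept for backward compatibility): flat text-model kwargs are
# routed into the nested text config because no `text_config` was given.
legacy = Qwen2_5_VLConfig(hidden_size=1024, num_hidden_layers=4, num_attention_heads=16)
print(legacy.text_config.hidden_size)       # 1024

# `get_text_config()` is the accessor the modeling code now relies on, e.g. in
# `_init_weights` and when sizing the LM head.
print(config.get_text_config().num_hidden_layers)  # 4
```

Because serialized checkpoints on the Hub may still carry a flat config dict, the kwargs-forwarding branch also keeps `from_pretrained` working on older configs without a conversion step.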
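The new `qwen2_5_vl_text` / `qwen2_vl_text` entries in the auto mappings make the standalone text backbone addressable by model type, and `Qwen2_5_VLModel` / `Qwen2VLModel` now declare the text config as their `config_class`. A small illustration, again assuming this patch is applied; the `MODEL_MAPPING` lookup is shown purely for introspection:

```python
from transformers import MODEL_MAPPING, AutoConfig, Qwen2_5_VLTextConfig

# "qwen2_5_vl_text" resolves to the standalone text config class ...
text_config = AutoConfig.for_model("qwen2_5_vl_text")
print(type(text_config).__name__)  # Qwen2_5_VLTextConfig

# ... and the base auto-model mapping points that config at Qwen2_5_VLModel,
# i.e. the language decoder is reachable without the vision tower.
print(MODEL_MAPPING[Qwen2_5_VLTextConfig].__name__)  # Qwen2_5_VLModel
```

This is also why the tests above skip `test_save_load_fast_init_from_base`: the base `Qwen2_5_VLModel` / `Qwen2VLModel` classes are now text-only and can no longer be initialized from the composite config.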