[qwen-vl] Standardize config (#37268)

* update

* fix tests

* fixup

* update

* skip this one

* fixup

* fix
Authored by Raushan Turganbay on 2025-04-17 09:38:12 +02:00, committed by GitHub
parent 4f96081aad
commit 3bc44eaaee
13 changed files with 202 additions and 55 deletions

View File

@@ -232,10 +232,15 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 [[autodoc]] Qwen2_5_VLConfig
+## Qwen2_5_VLTextConfig
+[[autodoc]] Qwen2_5_VLTextConfig
 ## Qwen2_5_VLProcessor
 [[autodoc]] Qwen2_5_VLProcessor
 ## Qwen2_5_VLModel
 [[autodoc]] Qwen2_5_VLModel

View File

@@ -278,6 +278,10 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
 [[autodoc]] Qwen2VLConfig
+## Qwen2VLTextConfig
+[[autodoc]] Qwen2VLTextConfig
 ## Qwen2VLImageProcessor
 [[autodoc]] Qwen2VLImageProcessor

View File

@@ -258,10 +258,12 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("qwen2", "Qwen2Config"),
         ("qwen2_5_omni", "Qwen2_5OmniConfig"),
         ("qwen2_5_vl", "Qwen2_5_VLConfig"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
         ("qwen2_audio", "Qwen2AudioConfig"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
         ("qwen2_moe", "Qwen2MoeConfig"),
         ("qwen2_vl", "Qwen2VLConfig"),
+        ("qwen2_vl_text", "Qwen2VLTextConfig"),
         ("qwen3", "Qwen3Config"),
         ("qwen3_moe", "Qwen3MoeConfig"),
         ("rag", "RagConfig"),
@@ -625,10 +627,12 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("qwen2", "Qwen2"),
         ("qwen2_5_omni", "Qwen2_5Omni"),
         ("qwen2_5_vl", "Qwen2_5_VL"),
+        ("qwen2_5_vl_text", "Qwen2_5_VL"),
         ("qwen2_audio", "Qwen2Audio"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoE"),
         ("qwen2_vl", "Qwen2VL"),
+        ("qwen2_vl_text", "Qwen2VL"),
         ("qwen3", "Qwen3"),
         ("qwen3_moe", "Qwen3MoE"),
         ("rag", "RAG"),
@@ -793,6 +797,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("qwen2_5_vl_text", "qwen2_5_vl"),
+        ("qwen2_vl_text", "qwen2_vl"),
         ("sam_vision_model", "sam"),
         ("llama4_text", "llama4"),
         ("blip_2_qformer", "blip_2"),

View File

@@ -234,9 +234,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("qdqbert", "QDQBertModel"),
         ("qwen2", "Qwen2Model"),
         ("qwen2_5_vl", "Qwen2_5_VLModel"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLModel"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoeModel"),
         ("qwen2_vl", "Qwen2VLModel"),
+        ("qwen2_vl_text", "Qwen2VLModel"),
         ("qwen3", "Qwen3Model"),
         ("qwen3_moe", "Qwen3MoeModel"),
         ("recurrent_gemma", "RecurrentGemmaModel"),

View File

@@ -1792,7 +1792,7 @@ QWEN2_5_OMNI_ATTENTION_CLASSES = {
 class Qwen2_5OmniDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size

View File

@@ -67,9 +67,9 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
         self.initializer_range = initializer_range
-class Qwen2_5_VLConfig(PretrainedConfig):
+class Qwen2_5_VLTextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a
     Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
@@ -77,7 +77,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
     Args:
         vocab_size (`int`, *optional*, defaults to 152064):
             Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
@@ -120,8 +119,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        vision_config (`Dict`, *optional*):
-            The config for the visual encoder initialization.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@@ -161,20 +158,20 @@ class Qwen2_5_VLConfig(PretrainedConfig):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
     ```python
-    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+    >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
     >>> # Initializing a Qwen2_5_VL style configuration
     >>> configuration = Qwen2_5_VLConfig()
     >>> # Initializing a model from the Qwen2-VL-7B style configuration
-    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+    >>> model = Qwen2_5_VLTextModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
-    model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    model_type = "qwen2_5_vl_text"
+    base_config_key = "text_config"
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `Qwen2_5_VL`
     base_model_tp_plan = {
@@ -211,15 +208,9 @@ class Qwen2_5_VLConfig(PretrainedConfig):
         sliding_window=4096,
         max_window_layers=80,
         attention_dropout=0.0,
-        vision_config=None,
         rope_scaling=None,
         **kwargs,
     ):
-        if isinstance(vision_config, dict):
-            self.vision_config = self.sub_configs["vision_config"](**vision_config)
-        elif vision_config is None:
-            self.vision_config = self.sub_configs["vision_config"]()
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -257,4 +248,67 @@ class Qwen2_5_VLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
-__all__ = ["Qwen2_5_VLConfig"]
+class Qwen2_5_VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the image prompt.
+    ```python
+    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+    >>> # Initializing a Qwen2_5_VL style configuration
+    >>> configuration = Qwen2_5_VLConfig()
+    >>> # Initializing a model from the Qwen2-VL-7B style configuration
+    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+    model_type = "qwen2_5_vl"
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            # For BC use all kwargs to init `TextConfig`
+            self.text_config = self.sub_configs["text_config"](**kwargs)
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        super().__init__(**kwargs)
+__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]

View File

@@ -48,7 +48,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
+from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 if is_flash_attn_available():
@@ -390,7 +390,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -566,7 +566,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, device=None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -680,7 +680,7 @@ class Qwen2_5_VLAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -989,7 +989,7 @@ QWEN2_5_VL_ATTENTION_CLASSES = {
 class Qwen2_5_VLDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1077,7 +1077,9 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
     Qwen2_5_VL_START_DOCSTRING,
 )
 class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
-    def __init__(self, config: Qwen2_5_VLConfig):
+    config_class = Qwen2_5_VLTextConfig
+    def __init__(self, config: Qwen2_5_VLTextConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1497,9 +1499,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
     def __init__(self, config):
         super().__init__(config)
         self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.model = Qwen2_5_VLModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        text_config = config.get_text_config()
+        self.model = Qwen2_5_VLModel._from_config(text_config)
+        self.vocab_size = text_config.vocab_size
+        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
         self.rope_deltas = None  # cache rope_deltas here
         # Initialize weights and apply final processing
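
`get_text_config()` is what lets the `_init_weights` and `lm_head` changes above work for both the composite and the text-only config. A small sketch of that fallback, assuming this branch:

```python
# get_text_config() returns the nested text config when one exists,
# and the config itself otherwise, so initializer_range, hidden_size
# and vocab_size are always read from the text backbone.
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig

composite = Qwen2_5_VLConfig(text_config={"hidden_size": 1024})
text_only = Qwen2_5_VLTextConfig(hidden_size=1024)

print(composite.get_text_config().hidden_size)        # 1024 (from the nested text_config)
print(text_only.get_text_config() is text_only)       # True (no nesting, the config itself)
print(composite.get_text_config().initializer_range)  # 0.02 default, used by _init_weights
```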

View File

@@ -28,7 +28,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     PatchMerger,
@@ -110,9 +110,13 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
         self.initializer_range = initializer_range
+class Qwen2_5_VLTextConfig(Qwen2VLTextConfig):
+    model_type = "qwen2_5_vl_text"
 class Qwen2_5_VLConfig(Qwen2VLConfig):
     model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
 class Qwen2_5_VLMLP(nn.Module):
@@ -227,7 +231,7 @@ class Qwen2_5_VLVisionBlock(nn.Module):
 class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -971,6 +975,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
 __all__ = [
     "Qwen2_5_VLConfig",
+    "Qwen2_5_VLTextConfig",
     "Qwen2_5_VLForConditionalGeneration",
     "Qwen2_5_VLModel",
     "Qwen2_5_VLPreTrainedModel",

View File

@@ -56,9 +56,9 @@ class Qwen2VLVisionConfig(PretrainedConfig):
         self.initializer_range = initializer_range
-class Qwen2VLConfig(PretrainedConfig):
+class Qwen2VLTextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
+    This is the configuration class to store the configuration of a [`Qwen2VLTextModel`]. It is used to instantiate a
     Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
@@ -66,7 +66,6 @@ class Qwen2VLConfig(PretrainedConfig):
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
     Args:
         vocab_size (`int`, *optional*, defaults to 152064):
             Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
@@ -109,8 +108,6 @@ class Qwen2VLConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        vision_config (`Dict`, *optional*):
-            The config for the visual encoder initialization.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@@ -150,20 +147,20 @@ class Qwen2VLConfig(PretrainedConfig):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
     ```python
-    >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
+    >>> from transformers import Qwen2VLTextModel, Qwen2VLConfig
     >>> # Initializing a Qwen2VL style configuration
     >>> configuration = Qwen2VLConfig()
     >>> # Initializing a model from the Qwen2-VL-7B style configuration
-    >>> model = Qwen2VLForConditionalGeneration(configuration)
+    >>> model = Qwen2VLTextModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
-    model_type = "qwen2_vl"
-    sub_configs = {"vision_config": Qwen2VLVisionConfig}
+    model_type = "qwen2_vl_text"
+    base_config_key = "text_config"
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `Qwen2VL`
     base_model_tp_plan = {
@@ -200,15 +197,9 @@ class Qwen2VLConfig(PretrainedConfig):
         sliding_window=4096,
         max_window_layers=80,
         attention_dropout=0.0,
-        vision_config=None,
         rope_scaling=None,
         **kwargs,
     ):
-        if isinstance(vision_config, dict):
-            self.vision_config = self.sub_configs["vision_config"](**vision_config)
-        elif vision_config is None:
-            self.vision_config = self.sub_configs["vision_config"]()
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -246,4 +237,67 @@ class Qwen2VLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
-__all__ = ["Qwen2VLConfig"]
+class Qwen2VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
+    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the image prompt.
+    ```python
+    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+    >>> # Initializing a Qwen2_5_VL style configuration
+    >>> configuration = Qwen2_5_VLConfig()
+    >>> # Initializing a model from the Qwen2-VL-7B style configuration
+    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+    model_type = "qwen2_vl"
+    sub_configs = {"vision_config": Qwen2VLVisionConfig, "text_config": Qwen2VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            # For BC use all kwargs to init `TextConfig`
+            self.text_config = self.sub_configs["text_config"](**kwargs)
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        super().__init__(**kwargs)
+__all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"]

View File

@@ -44,7 +44,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
+from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 if is_flash_attn_available():
@@ -101,7 +101,7 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
 class Qwen2VLRotaryEmbedding(nn.Module):
-    def __init__(self, config: Qwen2VLConfig, device=None):
+    def __init__(self, config: Qwen2VLTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -494,7 +494,7 @@ class Qwen2VLAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
-    def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: Qwen2VLTextConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -803,7 +803,7 @@ QWEN2_VL_ATTENTION_CLASSES = {
 class Qwen2VLDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
+    def __init__(self, config: Qwen2VLTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -919,8 +919,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -1029,7 +1028,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     QWEN2VL_START_DOCSTRING,
 )
 class Qwen2VLModel(Qwen2VLPreTrainedModel):
-    def __init__(self, config: Qwen2VLConfig):
+    config_class = Qwen2VLTextConfig
+    def __init__(self, config: Qwen2VLTextConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1410,9 +1411,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
     def __init__(self, config):
         super().__init__(config)
         self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.model = Qwen2VLModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        text_config = config.get_text_config()
+        self.model = Qwen2VLModel._from_config(text_config)
+        self.vocab_size = text_config.vocab_size
+        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here
        # Initialize weights and apply final processing
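
One visible consequence of the `config_class` override above: the bare `Qwen2VLModel` is now typed against the text config, while the conditional-generation wrapper keeps the composite config. A small check, assuming this branch:

```python
# The language backbone and the full model advertise different config classes.
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLModel

print(Qwen2VLModel.config_class.__name__)                     # Qwen2VLTextConfig
print(Qwen2VLForConditionalGeneration.config_class.__name__)  # Qwen2VLConfig
```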

View File

@@ -312,6 +312,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         super().test_prompt_lookup_decoding_matches_greedy_search()
+    @unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`")
+    def test_save_load_fast_init_from_base(self):
+        pass
 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):

View File

@@ -316,6 +316,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass
+    @unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`")
+    def test_save_load_fast_init_from_base(self):
+        pass
 @require_torch
 class Qwen2VLIntegrationTest(unittest.TestCase):

View File

@@ -345,6 +345,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_strings):
     # common and important attributes, even if they do not always appear in the modeling files
     attributes_to_allow = [
+        "initializer_range",
         "bos_index",
         "eos_index",
         "pad_index",
@@ -355,6 +356,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_strings):
         "image_seq_length",
         "video_seq_length",
         "image_size",
+        "text_config",  # may appear as `get_text_config()`
         "use_cache",
         "out_features",
         "out_indices",