[qwen-vl] Standardize config (#37268)

* update

* fix tests

* fixup

* update

* skip this one

* fixup

* fix
This commit is contained in:
Raushan Turganbay 2025-04-17 09:38:12 +02:00 committed by GitHub
parent 4f96081aad
commit 3bc44eaaee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 202 additions and 55 deletions

View File

@ -232,10 +232,15 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2_5_VLConfig
## Qwen2_5_VLTextConfig
[[autodoc]] Qwen2_5_VLTextConfig
## Qwen2_5_VLProcessor
[[autodoc]] Qwen2_5_VLProcessor
## Qwen2_5_VLModel
[[autodoc]] Qwen2_5_VLModel

View File

@ -278,6 +278,10 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2VLConfig
## Qwen2VLTextConfig
[[autodoc]] Qwen2VLTextConfig
## Qwen2VLImageProcessor
[[autodoc]] Qwen2VLImageProcessor

View File

@ -258,10 +258,12 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("qwen2", "Qwen2Config"),
("qwen2_5_omni", "Qwen2_5OmniConfig"),
("qwen2_5_vl", "Qwen2_5_VLConfig"),
("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
("qwen2_audio", "Qwen2AudioConfig"),
("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
("qwen2_moe", "Qwen2MoeConfig"),
("qwen2_vl", "Qwen2VLConfig"),
("qwen2_vl_text", "Qwen2VLTextConfig"),
("qwen3", "Qwen3Config"),
("qwen3_moe", "Qwen3MoeConfig"),
("rag", "RagConfig"),
@ -625,10 +627,12 @@ MODEL_NAMES_MAPPING = OrderedDict(
("qwen2", "Qwen2"),
("qwen2_5_omni", "Qwen2_5Omni"),
("qwen2_5_vl", "Qwen2_5_VL"),
("qwen2_5_vl_text", "Qwen2_5_VL"),
("qwen2_audio", "Qwen2Audio"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoE"),
("qwen2_vl", "Qwen2VL"),
("qwen2_vl_text", "Qwen2VL"),
("qwen3", "Qwen3"),
("qwen3_moe", "Qwen3MoE"),
("rag", "RAG"),
@ -793,6 +797,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("qwen2_5_vl_text", "qwen2_5_vl"),
("qwen2_vl_text", "qwen2_vl"),
("sam_vision_model", "sam"),
("llama4_text", "llama4"),
("blip_2_qformer", "blip_2"),

View File

@ -234,9 +234,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("qwen2_5_vl", "Qwen2_5_VLModel"),
("qwen2_5_vl_text", "Qwen2_5_VLModel"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoeModel"),
("qwen2_vl", "Qwen2VLModel"),
("qwen2_vl_text", "Qwen2VLModel"),
("qwen3", "Qwen3Model"),
("qwen3_moe", "Qwen3MoeModel"),
("recurrent_gemma", "RecurrentGemmaModel"),

View File

@ -1792,7 +1792,7 @@ QWEN2_5_OMNI_ATTENTION_CLASSES = {
class Qwen2_5OmniDecoderLayer(nn.Module):
def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

View File

@ -67,9 +67,9 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.initializer_range = initializer_range
class Qwen2_5_VLConfig(PretrainedConfig):
class Qwen2_5_VLTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
@ -77,7 +77,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
@ -120,8 +119,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@ -161,20 +158,20 @@ class Qwen2_5_VLConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> model = Qwen2_5_VLTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_5_vl"
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
model_type = "qwen2_5_vl_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2_5_VL`
base_model_tp_plan = {
@ -211,15 +208,9 @@ class Qwen2_5_VLConfig(PretrainedConfig):
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@ -257,4 +248,67 @@ class Qwen2_5_VLConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Qwen2_5_VLConfig"]
class Qwen2_5_VLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the image prompt.
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_5_vl"
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=151655,
video_token_id=151656,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
# For BC use all kwargs to init `TextConfig`
self.text_config = self.sub_configs["text_config"](**kwargs)
self.image_token_id = image_token_id
self.video_token_id = video_token_id
super().__init__(**kwargs)
__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]

View File

@ -48,7 +48,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
if is_flash_attn_available():
@ -390,7 +390,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
def _init_weights(self, module):
std = self.config.initializer_range
std = self.config.get_text_config().initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@ -566,7 +566,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLConfig, device=None):
def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@ -680,7 +680,7 @@ class Qwen2_5_VLAttention(nn.Module):
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None):
def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@ -989,7 +989,7 @@ QWEN2_5_VL_ATTENTION_CLASSES = {
class Qwen2_5_VLDecoderLayer(nn.Module):
def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int):
def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -1077,7 +1077,9 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
Qwen2_5_VL_START_DOCSTRING,
)
class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
def __init__(self, config: Qwen2_5_VLConfig):
config_class = Qwen2_5_VLTextConfig
def __init__(self, config: Qwen2_5_VLTextConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -1497,9 +1499,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
def __init__(self, config):
super().__init__(config)
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
self.model = Qwen2_5_VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
text_config = config.get_text_config()
self.model = Qwen2_5_VLModel._from_config(text_config)
self.vocab_size = text_config.vocab_size
self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing

View File

@ -28,7 +28,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
PatchEmbed,
PatchMerger,
@ -110,9 +110,13 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.initializer_range = initializer_range
class Qwen2_5_VLTextConfig(Qwen2VLTextConfig):
model_type = "qwen2_5_vl_text"
class Qwen2_5_VLConfig(Qwen2VLConfig):
model_type = "qwen2_5_vl"
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
class Qwen2_5_VLMLP(nn.Module):
@ -227,7 +231,7 @@ class Qwen2_5_VLVisionBlock(nn.Module):
class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
std = self.config.get_text_config().initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@ -971,6 +975,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
__all__ = [
"Qwen2_5_VLConfig",
"Qwen2_5_VLTextConfig",
"Qwen2_5_VLForConditionalGeneration",
"Qwen2_5_VLModel",
"Qwen2_5_VLPreTrainedModel",

View File

@ -56,9 +56,9 @@ class Qwen2VLVisionConfig(PretrainedConfig):
self.initializer_range = initializer_range
class Qwen2VLConfig(PretrainedConfig):
class Qwen2VLTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
This is the configuration class to store the configuration of a [`Qwen2VLTextModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
@ -66,7 +66,6 @@ class Qwen2VLConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 152064):
Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
@ -109,8 +108,6 @@ class Qwen2VLConfig(PretrainedConfig):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
vision_config (`Dict`, *optional*):
The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@ -150,20 +147,20 @@ class Qwen2VLConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
>>> from transformers import Qwen2VLTextModel, Qwen2VLConfig
>>> # Initializing a Qwen2VL style configuration
>>> configuration = Qwen2VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2VLForConditionalGeneration(configuration)
>>> model = Qwen2VLTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_vl"
sub_configs = {"vision_config": Qwen2VLVisionConfig}
model_type = "qwen2_vl_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2VL`
base_model_tp_plan = {
@ -200,15 +197,9 @@ class Qwen2VLConfig(PretrainedConfig):
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@ -246,4 +237,67 @@ class Qwen2VLConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Qwen2VLConfig"]
class Qwen2VLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
The config object or dictionary of the vision backbone.
image_token_id (`int`, *optional*, defaults to 151655):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 151656):
The video token index to encode the image prompt.
```python
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
>>> # Initializing a Qwen2_5_VL style configuration
>>> configuration = Qwen2_5_VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2_vl"
sub_configs = {"vision_config": Qwen2VLVisionConfig, "text_config": Qwen2VLTextConfig}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
image_token_id=151655,
video_token_id=151656,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
# For BC use all kwargs to init `TextConfig`
self.text_config = self.sub_configs["text_config"](**kwargs)
self.image_token_id = image_token_id
self.video_token_id = video_token_id
super().__init__(**kwargs)
__all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"]

View File

@ -44,7 +44,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
if is_flash_attn_available():
@ -101,7 +101,7 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
class Qwen2VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2VLConfig, device=None):
def __init__(self, config: Qwen2VLTextConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@ -494,7 +494,7 @@ class Qwen2VLAttention(nn.Module):
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
def __init__(self, config: Qwen2VLTextConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@ -803,7 +803,7 @@ QWEN2_VL_ATTENTION_CLASSES = {
class Qwen2VLDecoderLayer(nn.Module):
def __init__(self, config: Qwen2VLConfig, layer_idx: int):
def __init__(self, config: Qwen2VLTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
@ -919,8 +919,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
def _init_weights(self, module):
std = self.config.initializer_range
std = self.config.get_text_config().initializer_range
if isinstance(module, (nn.Linear, nn.Conv3d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@ -1029,7 +1028,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
QWEN2VL_START_DOCSTRING,
)
class Qwen2VLModel(Qwen2VLPreTrainedModel):
def __init__(self, config: Qwen2VLConfig):
config_class = Qwen2VLTextConfig
def __init__(self, config: Qwen2VLTextConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -1410,9 +1411,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
def __init__(self, config):
super().__init__(config)
self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
self.model = Qwen2VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
text_config = config.get_text_config()
self.model = Qwen2VLModel._from_config(text_config)
self.vocab_size = text_config.vocab_size
self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing

View File

@ -312,6 +312,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_prompt_lookup_decoding_matches_greedy_search(self):
super().test_prompt_lookup_decoding_matches_greedy_search()
@unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`")
def test_save_load_fast_init_from_base(self):
pass
@require_torch
class Qwen2_5_VLIntegrationTest(unittest.TestCase):

View File

@ -316,6 +316,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`")
def test_save_load_fast_init_from_base(self):
pass
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):

View File

@ -345,6 +345,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
# common and important attributes, even if they do not always appear in the modeling files
attributes_to_allow = [
"initializer_range",
"bos_index",
"eos_index",
"pad_index",
@ -355,6 +356,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"image_seq_length",
"video_seq_length",
"image_size",
"text_config", # may appear as `get_text_config()`
"use_cache",
"out_features",
"out_indices",