[qwen-vl] Standardize config (#37268)
* update
* fix tests
* fixup
* update
* skip this one
* fixup
* fix

parent 4f96081aad · commit 3bc44eaaee
@@ -232,10 +232,15 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 [[autodoc]] Qwen2_5_VLConfig
 
+## Qwen2_5_VLTextConfig
+
+[[autodoc]] Qwen2_5_VLTextConfig
+
 ## Qwen2_5_VLProcessor
 
 [[autodoc]] Qwen2_5_VLProcessor
 
 ## Qwen2_5_VLModel
 
 [[autodoc]] Qwen2_5_VLModel
@@ -278,6 +278,10 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
 [[autodoc]] Qwen2VLConfig
 
+## Qwen2VLTextConfig
+
+[[autodoc]] Qwen2VLTextConfig
+
 ## Qwen2VLImageProcessor
 
 [[autodoc]] Qwen2VLImageProcessor
@@ -258,10 +258,12 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("qwen2", "Qwen2Config"),
         ("qwen2_5_omni", "Qwen2_5OmniConfig"),
         ("qwen2_5_vl", "Qwen2_5_VLConfig"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
         ("qwen2_audio", "Qwen2AudioConfig"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
         ("qwen2_moe", "Qwen2MoeConfig"),
         ("qwen2_vl", "Qwen2VLConfig"),
+        ("qwen2_vl_text", "Qwen2VLTextConfig"),
         ("qwen3", "Qwen3Config"),
         ("qwen3_moe", "Qwen3MoeConfig"),
         ("rag", "RagConfig"),
@@ -625,10 +627,12 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("qwen2", "Qwen2"),
         ("qwen2_5_omni", "Qwen2_5Omni"),
         ("qwen2_5_vl", "Qwen2_5_VL"),
+        ("qwen2_5_vl_text", "Qwen2_5_VL"),
         ("qwen2_audio", "Qwen2Audio"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoE"),
         ("qwen2_vl", "Qwen2VL"),
+        ("qwen2_vl_text", "Qwen2VL"),
         ("qwen3", "Qwen3"),
         ("qwen3_moe", "Qwen3MoE"),
         ("rag", "RAG"),
@@ -793,6 +797,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("qwen2_5_vl_text", "qwen2_5_vl"),
+        ("qwen2_vl_text", "qwen2_vl"),
         ("sam_vision_model", "sam"),
         ("llama4_text", "llama4"),
         ("blip_2_qformer", "blip_2"),
@@ -234,9 +234,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("qdqbert", "QDQBertModel"),
         ("qwen2", "Qwen2Model"),
         ("qwen2_5_vl", "Qwen2_5_VLModel"),
+        ("qwen2_5_vl_text", "Qwen2_5_VLModel"),
         ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
         ("qwen2_moe", "Qwen2MoeModel"),
         ("qwen2_vl", "Qwen2VLModel"),
+        ("qwen2_vl_text", "Qwen2VLModel"),
         ("qwen3", "Qwen3Model"),
         ("qwen3_moe", "Qwen3MoeModel"),
         ("recurrent_gemma", "RecurrentGemmaModel"),
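For context, a minimal sketch of what the new auto-mapping entries enable (hypothetical usage, assuming an installed transformers build that includes this commit); `AutoConfig.for_model` is the existing factory that resolves a `model_type` string through the mapping above.

```python
# Sketch only (assumes this commit): the new "*_text" model types resolve
# through the auto classes like any other registered model type.
from transformers import AutoConfig

text_config = AutoConfig.for_model("qwen2_5_vl_text")
print(type(text_config).__name__)  # expected: Qwen2_5_VLTextConfig
```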
@@ -1792,7 +1792,7 @@ QWEN2_5_OMNI_ATTENTION_CLASSES = {
 
 
 class Qwen2_5OmniDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5OmniConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -67,9 +67,9 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
         self.initializer_range = initializer_range
 
 
-class Qwen2_5_VLConfig(PretrainedConfig):
+class Qwen2_5_VLTextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLTextModel`]. It is used to instantiate a
     Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
@@ -77,7 +77,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
-
     Args:
         vocab_size (`int`, *optional*, defaults to 152064):
             Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
@@ -120,8 +119,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        vision_config (`Dict`, *optional*):
-            The config for the visual encoder initialization.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@@ -161,20 +158,20 @@ class Qwen2_5_VLConfig(PretrainedConfig):
                 Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
 
     ```python
-    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+    >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
 
     >>> # Initializing a Qwen2_5_VL style configuration
     >>> configuration = Qwen2_5_VLConfig()
 
     >>> # Initializing a model from the Qwen2-VL-7B style configuration
-    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+    >>> model = Qwen2_5_VLTextModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
 
-    model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    model_type = "qwen2_5_vl_text"
+    base_config_key = "text_config"
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `Qwen2_5_VL`
     base_model_tp_plan = {
@@ -211,15 +208,9 @@ class Qwen2_5_VLConfig(PretrainedConfig):
         sliding_window=4096,
         max_window_layers=80,
         attention_dropout=0.0,
-        vision_config=None,
         rope_scaling=None,
         **kwargs,
     ):
-        if isinstance(vision_config, dict):
-            self.vision_config = self.sub_configs["vision_config"](**vision_config)
-        elif vision_config is None:
-            self.vision_config = self.sub_configs["vision_config"]()
-
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -257,4 +248,67 @@ class Qwen2_5_VLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-__all__ = ["Qwen2_5_VLConfig"]
+class Qwen2_5_VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
+    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the image prompt.
+
+    ```python
+    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+
+    >>> # Initializing a Qwen2_5_VL style configuration
+    >>> configuration = Qwen2_5_VLConfig()
+
+    >>> # Initializing a model from the Qwen2-VL-7B style configuration
+    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_5_vl"
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            # For BC use all kwargs to init `TextConfig`
+            self.text_config = self.sub_configs["text_config"](**kwargs)
+
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["Qwen2_5_VLConfig", "Qwen2_5_VLTextConfig"]
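A minimal sketch of the standardized layout introduced above (hypothetical usage, assuming an installed transformers build that contains this commit): the composite `Qwen2_5_VLConfig` nests a `Qwen2_5_VLTextConfig` under `text_config`, while flat text kwargs are still routed there for backward compatibility.

```python
# Sketch only (assumes this commit is available in the installed transformers).
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig

# Nested style: text and vision backbones are configured explicitly.
config = Qwen2_5_VLConfig(
    text_config=Qwen2_5_VLTextConfig(hidden_size=1024, num_hidden_layers=2),
    vision_config={"depth": 2},
)
print(type(config.text_config).__name__)  # Qwen2_5_VLTextConfig

# Backward-compatible style: flat text kwargs are forwarded into `text_config`.
legacy = Qwen2_5_VLConfig(hidden_size=1024, num_hidden_layers=2)
print(legacy.text_config.hidden_size)  # 1024
```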
@@ -48,7 +48,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
+from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
 if is_flash_attn_available():
@@ -390,7 +390,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -566,7 +566,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
 
 
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, device=None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -680,7 +680,7 @@ class Qwen2_5_VLAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -989,7 +989,7 @@ QWEN2_5_VL_ATTENTION_CLASSES = {
 
 
 class Qwen2_5_VLDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int):
+    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1077,7 +1077,9 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
     Qwen2_5_VL_START_DOCSTRING,
 )
 class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
-    def __init__(self, config: Qwen2_5_VLConfig):
+    config_class = Qwen2_5_VLTextConfig
+
+    def __init__(self, config: Qwen2_5_VLTextConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1497,9 +1499,11 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
     def __init__(self, config):
         super().__init__(config)
         self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.model = Qwen2_5_VLModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        text_config = config.get_text_config()
+        self.model = Qwen2_5_VLModel._from_config(text_config)
+        self.vocab_size = text_config.vocab_size
+        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
         self.rope_deltas = None  # cache rope_deltas here
 
         # Initialize weights and apply final processing
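For the model-side change above, a short sketch (hypothetical usage, assuming this commit) of what `get_text_config()` resolves to under the nested layout:

```python
# Sketch only: `get_text_config()` returns the nested text sub-config,
# which the language backbone and lm_head are now built from.
from transformers import Qwen2_5_VLConfig

config = Qwen2_5_VLConfig()
text_config = config.get_text_config()
assert text_config is config.text_config
print(text_config.hidden_size, text_config.vocab_size)
```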
@@ -28,7 +28,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 
-from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     PatchMerger,
@@ -110,9 +110,13 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
         self.initializer_range = initializer_range
 
 
+class Qwen2_5_VLTextConfig(Qwen2VLTextConfig):
+    model_type = "qwen2_5_vl_text"
+
+
 class Qwen2_5_VLConfig(Qwen2VLConfig):
     model_type = "qwen2_5_vl"
-    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig}
+    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLTextConfig}
 
 
 class Qwen2_5_VLMLP(nn.Module):
@@ -227,7 +231,7 @@ class Qwen2_5_VLVisionBlock(nn.Module):
 
 class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
     def _init_weights(self, module):
-        std = self.config.initializer_range
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -971,6 +975,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
 
 __all__ = [
     "Qwen2_5_VLConfig",
+    "Qwen2_5_VLTextConfig",
     "Qwen2_5_VLForConditionalGeneration",
     "Qwen2_5_VLModel",
     "Qwen2_5_VLPreTrainedModel",
@@ -56,9 +56,9 @@ class Qwen2VLVisionConfig(PretrainedConfig):
         self.initializer_range = initializer_range
 
 
-class Qwen2VLConfig(PretrainedConfig):
+class Qwen2VLTextConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
+    This is the configuration class to store the configuration of a [`Qwen2VLTextModel`]. It is used to instantiate a
     Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
@@ -66,7 +66,6 @@ class Qwen2VLConfig(PretrainedConfig):
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
-
     Args:
         vocab_size (`int`, *optional*, defaults to 152064):
             Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
@@ -109,8 +108,6 @@ class Qwen2VLConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-        vision_config (`Dict`, *optional*):
-            The config for the visual encoder initialization.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
@@ -150,20 +147,20 @@ class Qwen2VLConfig(PretrainedConfig):
                 Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
 
     ```python
-    >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
+    >>> from transformers import Qwen2VLTextModel, Qwen2VLConfig
 
     >>> # Initializing a Qwen2VL style configuration
     >>> configuration = Qwen2VLConfig()
 
     >>> # Initializing a model from the Qwen2-VL-7B style configuration
-    >>> model = Qwen2VLForConditionalGeneration(configuration)
+    >>> model = Qwen2VLTextModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
 
-    model_type = "qwen2_vl"
-    sub_configs = {"vision_config": Qwen2VLVisionConfig}
+    model_type = "qwen2_vl_text"
+    base_config_key = "text_config"
     keys_to_ignore_at_inference = ["past_key_values"]
     # Default tensor parallel plan for base model `Qwen2VL`
     base_model_tp_plan = {
@@ -200,15 +197,9 @@ class Qwen2VLConfig(PretrainedConfig):
         sliding_window=4096,
         max_window_layers=80,
         attention_dropout=0.0,
-        vision_config=None,
         rope_scaling=None,
         **kwargs,
     ):
-        if isinstance(vision_config, dict):
-            self.vision_config = self.sub_configs["vision_config"](**vision_config)
-        elif vision_config is None:
-            self.vision_config = self.sub_configs["vision_config"]()
-
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
@@ -246,4 +237,67 @@ class Qwen2VLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-__all__ = ["Qwen2VLConfig"]
+class Qwen2VLConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
+    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLTextConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen2_5_VLVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        image_token_id (`int`, *optional*, defaults to 151655):
+            The image token index to encode the image prompt.
+        video_token_id (`int`, *optional*, defaults to 151656):
+            The video token index to encode the image prompt.
+
+    ```python
+    >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
+
+    >>> # Initializing a Qwen2_5_VL style configuration
+    >>> configuration = Qwen2_5_VLConfig()
+
+    >>> # Initializing a model from the Qwen2-VL-7B style configuration
+    >>> model = Qwen2_5_VLForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_vl"
+    sub_configs = {"vision_config": Qwen2VLVisionConfig, "text_config": Qwen2VLTextConfig}
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        image_token_id=151655,
+        video_token_id=151656,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
+        elif vision_config is None:
+            self.vision_config = self.sub_configs["vision_config"]()
+
+        if isinstance(text_config, dict):
+            self.text_config = self.sub_configs["text_config"](**text_config)
+        elif text_config is None:
+            # For BC use all kwargs to init `TextConfig`
+            self.text_config = self.sub_configs["text_config"](**kwargs)
+
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"]
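The same pattern applies to Qwen2-VL. A small sketch (hypothetical usage, assuming this commit) of how the serialized config now nests the text fields:

```python
# Sketch only: text fields are serialized under "text_config",
# alongside the existing "vision_config".
from transformers import Qwen2VLConfig

config = Qwen2VLConfig()
config_dict = config.to_dict()
print("text_config" in config_dict, "vision_config" in config_dict)  # True True
print(config_dict["text_config"]["hidden_size"])
```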
@@ -44,7 +44,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
+from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 
 
 if is_flash_attn_available():
@@ -101,7 +101,7 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
 
 
 class Qwen2VLRotaryEmbedding(nn.Module):
-    def __init__(self, config: Qwen2VLConfig, device=None):
+    def __init__(self, config: Qwen2VLTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -494,7 +494,7 @@ class Qwen2VLAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: Qwen2VLTextConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -803,7 +803,7 @@ QWEN2_VL_ATTENTION_CLASSES = {
 
 
 class Qwen2VLDecoderLayer(nn.Module):
-    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
+    def __init__(self, config: Qwen2VLTextConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -919,8 +919,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
-
+        std = self.config.get_text_config().initializer_range
         if isinstance(module, (nn.Linear, nn.Conv3d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -1029,7 +1028,9 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     QWEN2VL_START_DOCSTRING,
 )
 class Qwen2VLModel(Qwen2VLPreTrainedModel):
-    def __init__(self, config: Qwen2VLConfig):
+    config_class = Qwen2VLTextConfig
+
+    def __init__(self, config: Qwen2VLTextConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1410,9 +1411,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
     def __init__(self, config):
         super().__init__(config)
         self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
-        self.model = Qwen2VLModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        text_config = config.get_text_config()
+        self.model = Qwen2VLModel._from_config(text_config)
+        self.vocab_size = text_config.vocab_size
+        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
         self.rope_deltas = None  # cache rope_deltas here
 
         # Initialize weights and apply final processing
@@ -312,6 +312,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         super().test_prompt_lookup_decoding_matches_greedy_search()
 
+    @unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`")
+    def test_save_load_fast_init_from_base(self):
+        pass
+
 
 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):
@@ -316,6 +316,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass
 
+    @unittest.skip(reason="The base class is LM only and cannot be init with XModelConfig`")
+    def test_save_load_fast_init_from_base(self):
+        pass
+
 
 @require_torch
 class Qwen2VLIntegrationTest(unittest.TestCase):
@@ -345,6 +345,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
 
     # common and important attributes, even if they do not always appear in the modeling files
     attributes_to_allow = [
+        "initializer_range",
         "bos_index",
         "eos_index",
         "pad_index",
@@ -355,6 +356,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
         "image_seq_length",
         "video_seq_length",
         "image_size",
+        "text_config",  # may appear as `get_text_config()`
         "use_cache",
         "out_features",
         "out_indices",