Add InternVL (2.5 MPO) (#35968)

* initial commit

* add convert internvl

* add first end-to-end working internvl

* nit prompt and image proc

* add working chat template

* add conversion llama-based models

* add tests

* pass all tests

* fix isort

* fix modular after main merge

* add video processing for internvl

* add support for interlaced images and videos

* Remove processing and config from modular, add more tests

* add llama model tests

* Modify processor for compatibility with refactored got ocr image processor

* add comments in processor

* Add docs and nits

* change video processing to use custom sample_indices_fn

* rebase and fix tests

* add processor tests

* Add changes Raushan review

* Use the new attention interface for the vision model

* nits

* add support for custom video_load_backend

* remove mention to InternVLTokenizer

* refactor vision model to simplify logic

* refactor processor for better readibility

* fix copies

* fix require av processor test

* refactor internVL vision

* Update processor and fix processing tests

* fix docstring

* update convert_weights for internvl3

* change image processor to fast by default

* remove do_center_crop=True in convert_weights

* force use_cache to True

* push_to_hub before reloading

* fix internVLVision for larger models

* update convert weight for qk norm

* fix convert_weights

* fix eos_token_id in convert

* update docs and integration tests

* make modifs after review

* fix wrong k_norm and reduce modular

* change image_token_index to image_token_id

* change checkpoint to OpenGVLab org

* last nits

* explicitely del self.num_key_value_groups

* add extra special tokens
Yoni Gozlan 2025-04-18 18:57:33 +02:00 committed by GitHub
parent b0c6ff5e13
commit a245011252
20 changed files with 4447 additions and 5 deletions

View File

@ -953,6 +953,8 @@
title: InstructBLIP
- local: model_doc/instructblipvideo
title: InstructBlipVideo
- local: model_doc/internvl
title: InternVL
- local: model_doc/janus
title: Janus
- local: model_doc/kosmos-2

View File

@ -0,0 +1,349 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
</div>
</div>
# InternVL
The InternVL3 family of Visual Language Models was introduced in [InternVL3: Exploring Advanced Training and Test-Time Recipes for Open-Source Multimodal Models](https://huggingface.co/papers/2504.10479).
The abstract from the paper is the following:
*We introduce InternVL3, a significant advancement in the InternVL series featuring a native multimodal pre-training paradigm. Rather than adapting a text-only large language model (LLM) into a multimodal large language model (MLLM) that supports visual inputs, InternVL3 jointly acquires multimodal and linguistic capabilities from both diverse multimodal data and pure-text corpora during a single pre-training stage. This unified training paradigm effectively addresses the complexities and alignment challenges commonly encountered in conventional post-hoc training pipelines for MLLMs. To further improve performance and scalability, InternVL3 incorporates variable visual position encoding (V2PE) to support extended multimodal contexts, employs advanced post-training techniques such as supervised fine-tuning (SFT) and mixed preference optimization (MPO), and adopts test-time scaling strategies alongside an optimized training infrastructure. Extensive empirical evaluations demonstrate that InternVL3 delivers superior performance across a wide range of multi-modal tasks. In particular, InternVL3-78B achieves a score of 72.2 on the MMMU benchmark, setting a new state-of-the-art among open-source MLLMs. Its capabilities remain highly competitive with leading proprietary models, including ChatGPT-4o, Claude 3.5 Sonnet, and Gemini 2.5 Pro, while also maintaining strong pure-language proficiency. In pursuit of open-science principles, we will publicly release both the training data and model weights to foster further research and development in next-generation MLLMs.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/internvl_architecture.png" alt="drawing" width="600"/>
<small> Overview of the InternVL3 model architecture, which is the same as InternVL2.5. Taken from the <a href="https://huggingface.co/OpenGVLab/InternVL3-1B">original checkpoint</a>. </small>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/internvl_overview_performance.png" alt="drawing" width="600"/>
<small> Comparison of InternVL3 performance on OpenCompass against other SOTA VLLMs. Taken from the <a href="https://huggingface.co/OpenGVLab/InternVL3-1B">original checkpoint</a>. </small>
This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/OpenGVLab/InternVL).
## Usage example
### Inference with Pipeline
Here is how you can use the `image-text-to-text` pipeline to perform inference with the `InternVL3` models in just a few lines of code:
```python
>>> from transformers import pipeline
>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {
...                 "type": "image",
...                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
...             },
...             {"type": "text", "text": "Describe this image."},
...         ],
...     },
... ]
>>> pipe = pipeline("image-text-to-text", model="OpenGVLab/InternVL3-1B-hf")
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
>>> outputs[0]["generated_text"]
'The image showcases a vibrant scene of nature, featuring several flowers and a bee. \n\n1. **Foreground Flowers**: \n - The primary focus is on a large, pink cosmos flower with a prominent yellow center. The petals are soft and slightly r'
```
### Inference on a single image
This example demonstrates how to perform inference on a single image with the InternVL models using chat templates.
> [!NOTE]
> The model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
...             {"type": "text", "text": "Please describe the image explicitly."},
...         ],
...     }
... ]
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
>>> generate_ids = model.generate(**inputs, max_new_tokens=50)
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
>>> decoded_output
'The image shows two cats lying on a pink blanket. The cat on the left is a tabby with a mix of brown, black, and white fur, and it appears to be sleeping with its head resting on the blanket. The cat on the'
```
### Text-only generation
This example shows how to generate text using the InternVL model without providing any image input.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "text", "text": "Write a haiku"},
...         ],
...     }
... ]
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
>>> generate_ids = model.generate(**inputs, max_new_tokens=50)
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
>>> print(decoded_output)
"Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."
```
### Batched image and text inputs
InternVL models also support batched image and text inputs.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
...                 {"type": "text", "text": "Describe this image"},
...             ],
...         },
...     ],
... ]
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
>>> output = model.generate(**inputs, max_new_tokens=25)
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace.",
'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of']
```
### Batched multi-image input
This implementation of the InternVL models supports batched text-image inputs with a different number of images for each prompt.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
...             ],
...         },
...     ],
... ]
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
>>> output = model.generate(**inputs, max_new_tokens=25)
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace.",
'user\n\n\nThese images depict two different landmarks. Can you identify them?\nassistant\nYes, these images depict the Statue of Liberty and the Golden Gate Bridge.']
```
### Video input
InternVL models can also handle video inputs. Here is an example of how to perform inference on a video input using chat templates.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
>>> import torch
>>> model_checkpoint = "OpenGVLab/InternVL3-8B-hf"
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, quantization_config=quantization_config)
>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {
...                 "type": "video",
...                 "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
...             },
...             {"type": "text", "text": "What type of shot is the man performing?"},
...         ],
...     }
... ]
>>> inputs = processor.apply_chat_template(
...     messages,
...     return_tensors="pt",
...     add_generation_prompt=True,
...     tokenize=True,
...     return_dict=True,
... ).to(model.device, dtype=torch.float16)
>>> output = model.generate(**inputs, max_new_tokens=25)
>>> decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
>>> decoded_output
'The man is performing a forehand shot.'
```
### Interleaved image and video inputs
This example showcases how to handle a batch of chat conversations with interleaved image and video inputs using chat templates.
```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
>>> import torch
>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "video", "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"},
...                 {"type": "text", "text": "What type of shot is the man performing?"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
... ]
>>> inputs = processor.apply_chat_template(
...     messages,
...     padding=True,
...     add_generation_prompt=True,
...     tokenize=True,
...     return_dict=True,
...     return_tensors="pt",
... ).to(model.device, dtype=torch.bfloat16)
>>> outputs = model.generate(**inputs, max_new_tokens=25)
>>> decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)
>>> decoded_outputs
['user\n\n\nThese images depict two different landmarks. Can you identify them?\nassistant\nThe images depict the Statue of Liberty and the Golden Gate Bridge.',
'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot',
"user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace."]
```
## InternVLVisionConfig
[[autodoc]] InternVLVisionConfig
## InternVLConfig
[[autodoc]] InternVLConfig
## InternVLVisionModel
[[autodoc]] InternVLVisionModel
- forward
## InternVLForConditionalGeneration
[[autodoc]] InternVLForConditionalGeneration
- forward
## InternVLProcessor
[[autodoc]] InternVLProcessor

View File

@ -18,7 +18,7 @@ from collections.abc import Iterable
from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import Callable, Optional, Union
from urllib.parse import urlparse
import numpy as np
@ -77,9 +77,8 @@ if is_vision_available():
pil_torch_interpolation_mapping = {}
if TYPE_CHECKING:
if is_torch_available():
import torch
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@ -162,6 +161,15 @@ def is_valid_list_of_images(images: list):
return images and all(is_valid_image(image) for image in images)
def concatenate_list(input_list):
if isinstance(input_list[0], list):
return [item for sublist in input_list for item in sublist]
elif isinstance(input_list[0], np.ndarray):
return np.concatenate(input_list, axis=0)
elif isinstance(input_list[0], torch.Tensor):
return torch.cat(input_list, dim=0)
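# e.g. concatenate_list([[1, 2], [3]]) -> [1, 2, 3]; lists of np.ndarray or torch.Tensor elements
# are instead concatenated along their first dimension.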
def valid_images(imgs):
# If we have an list of images, make sure every image is valid
if isinstance(imgs, (list, tuple)):

View File

@ -143,6 +143,7 @@ if TYPE_CHECKING:
from .informer import *
from .instructblip import *
from .instructblipvideo import *
from .internvl import *
from .jamba import *
from .janus import *
from .jetmoe import *

View File

@ -162,6 +162,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
("instructblipvideo", "InstructBlipVideoConfig"),
("internvl", "InternVLConfig"),
("internvl_vision", "InternVLVisionConfig"),
("jamba", "JambaConfig"),
("janus", "JanusConfig"),
("jetmoe", "JetMoeConfig"),
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("informer", "Informer"),
("instructblip", "InstructBLIP"),
("instructblipvideo", "InstructBlipVideo"),
("internvl", "InternVL"),
("internvl_vision", "InternVLVision"),
("jamba", "Jamba"),
("janus", "Janus"),
("jetmoe", "JetMoe"),
@ -797,6 +801,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("internvl_vision", "internvl"),
("qwen2_5_vl_text", "qwen2_5_vl"),
("qwen2_vl_text", "qwen2_vl"),
("sam_vision_model", "sam"),

View File

@ -151,6 +151,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("internvl_vision", "InternVLVisionModel"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
("jetmoe", "JetMoeModel"),
@ -862,6 +863,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("internvl", "InternVLForConditionalGeneration"),
("janus", "JanusForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llama4", "Llama4ForConditionalGeneration"),

View File

@ -75,6 +75,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("idefics3", "Idefics3Processor"),
("instructblip", "InstructBlipProcessor"),
("instructblipvideo", "InstructBlipVideoProcessor"),
("internvl", "InternVLProcessor"),
("janus", "JanusProcessor"),
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),

View File

@ -258,6 +258,7 @@ else:
("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
(
"jamba",
(

View File

@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_internvl import *
from .modeling_internvl import *
from .processing_internvl import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig
class InternVLVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`InternVLVisionModel`]. It is used to instantiate an InternVLVisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
a similar configuration to that of InternVL3-1B.
e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
use_qk_norm (`bool`, *optional*, defaults to `False`):
Whether to apply normalization to the queries and keys before the attention operation.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for attention weights.
projection_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for the projection layer.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use in the encoder. Can be `"layer_norm"` or `"rms_norm"`.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
image_size (`int` or `list[int]`, *optional*, defaults to `[448, 448]`):
The size (resolution) of each image.
patch_size (`int` or `list[int]`, *optional*, defaults to `[14, 14]`):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
use_mask_token (`bool`, *optional*, defaults to `False`):
Whether to use a mask token for masked image modeling.
use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
Whether to use BERT-style absolute position embeddings.
layer_scale_init_value (`float`, *optional*, defaults to 0.1):
Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
use_mean_pooling (`bool`, *optional*, defaults to `True`):
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
CLS token, before applying the classification head.
Example:
```python
>>> from transformers import InternVLVisionConfig, InternVLVisionModel
>>> # Initializing a InternVLVisionModel OpenGVLab/InternVL3-1B-hf style configuration
>>> configuration = InternVLVisionConfig()
>>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
>>> model = InternVLVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "internvl_vision"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
attention_bias=False,
use_qk_norm=False,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_dropout=0.0,
projection_dropout=0.0,
initializer_range=0.02,
norm_type="layer_norm",
layer_norm_eps=1e-06,
image_size=[448, 448],
patch_size=[14, 14],
num_channels=3,
use_mask_token=False,
use_absolute_position_embeddings=True,
layer_scale_init_value=0.1,
use_mean_pooling=True,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_bias = attention_bias
self.use_qk_norm = use_qk_norm
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_dropout = attention_dropout
self.projection_dropout = projection_dropout
self.initializer_range = initializer_range
self.norm_type = norm_type
self.layer_norm_eps = layer_norm_eps
image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.use_mask_token = use_mask_token
self.use_absolute_position_embeddings = use_absolute_position_embeddings
self.layer_scale_init_value = layer_scale_init_value
self.use_mean_pooling = use_mean_pooling
class InternVLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`InternVLForConditionalGeneration`]. It is used to instantiate an
InternVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of InternVL3-1B.
e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `InternVLVisionConfig`):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
The config object or dictionary of the text backbone.
image_token_id (`int`, *optional*, defaults to 151667):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 256):
Number of image tokens to use per image patch.
downsample_ratio (`float`, *optional*, defaults to 0.5):
Factor by which to downsample the image.
projector_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the projector.
vision_feature_layer (`int`, *optional*, defaults to -1):
The index of the layer to use as the image features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
```python
>>> from transformers import InternVLForConditionalGeneration, InternVLConfig
>>> # Initializing a InternVL style configuration
>>> configuration = InternVLConfig()
>>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
>>> model = InternVLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "internvl"
sub_configs = {"text_config": AutoConfig, "vision_config": InternVLVisionConfig}
def __init__(
self,
vision_config=None,
text_config=None,
image_token_id=151667,
image_seq_length=256,
downsample_ratio=0.5,
projector_hidden_act="gelu",
vision_feature_layer=-1,
vision_feature_select_strategy="default",
**kwargs,
):
self.image_token_id = image_token_id
self.image_seq_length = image_seq_length
self.downsample_ratio = downsample_ratio
self.projector_hidden_act = projector_hidden_act
self.vision_feature_layer = vision_feature_layer
self.vision_feature_select_strategy = vision_feature_select_strategy
if isinstance(vision_config, dict):
self.vision_config = InternVLVisionConfig(**vision_config)
elif isinstance(vision_config, InternVLVisionConfig):
self.vision_config = vision_config
elif vision_config is None:
self.vision_config = InternVLVisionConfig()
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["qwen2"]()
self.text_config = text_config
super().__init__(**kwargs)
__all__ = ["InternVLVisionConfig", "InternVLConfig"]

View File

@ -0,0 +1,417 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
import torch
from einops import rearrange
from transformers import (
AutoModel,
AutoTokenizer,
GenerationConfig,
GotOcr2ImageProcessorFast,
InternVLConfig,
InternVLForConditionalGeneration,
InternVLProcessor,
InternVLVisionConfig,
LlamaConfig,
Qwen2Config,
)
LM_TYPE_CORRESPONDENCE = {
"OpenGVLab/InternVL2_5-1B-MPO": "qwen2",
"OpenGVLab/InternVL2_5-2B-MPO": "llama",
"OpenGVLab/InternVL2_5-4B-MPO": "qwen2",
"OpenGVLab/InternVL2_5-8B-MPO": "llama",
"OpenGVLab/InternVL2_5-26B-MPO": "llama",
"OpenGVLab/InternVL2_5-38B-MPO": "qwen2",
"OpenGVLab/InternVL2_5-78B-MPO": "qwen2",
"OpenGVLab/InternVL3-1B": "qwen2",
"OpenGVLab/InternVL3-2B": "qwen2",
"OpenGVLab/InternVL3-8B": "qwen2",
"OpenGVLab/InternVL3-9B": "llama",
"OpenGVLab/InternVL3-14B": "qwen2",
"OpenGVLab/InternVL3-38B": "qwen2",
"OpenGVLab/InternVL3-78B": "qwen2",
}
UNNECESSARY_CONFIG_KEYS = [ "_name_or_path", "_attn_implementation_autoset", "auto_map", "use_bfloat16", "use_flash_attn", "bias", "laux_allreduce", "moe_coeff_ratio", "moe_intermediate_size", "moe_output_scale", "noisy_gate_policy", "shared_expert_intermediate_size", "use_residual", "use_moe", "use_rts", "use_weighted_residual", "moe_config", "num_experts", "num_routed_experts", "num_shared_experts", "capacity_factor", "eval_capacity_factor", "drop_path_rate"] # fmt: skip
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION = {
# Vision encoder mapping
r"vision_model": r"vision_tower",
r"layers": r"layer",
r"class_embedding": r"cls_token",
r"position_embedding": r"position_embeddings",
r"patch_embedding": r"patch_embeddings.projection",
r"ls(\d+)": r"lambda_\1",
r"attn.proj": r"attention.projection_layer",
r"attn.dropout": r"attention.projection_dropout",
r"attn": r"attention",
r"norm1": r"layernorm_before",
r"norm2": r"layernorm_after",
}
ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA = {
# Text decoder mapping (Llama-based checkpoints)
r"tok_embeddings": r"embed_tokens",
r"attention.wo": r"self_attn.o_proj",
r"feed_forward.w1": r"mlp.gate_proj",
r"feed_forward.w2": r"mlp.down_proj",
r"feed_forward.w3": r"mlp.up_proj",
r"attention_norm": r"input_layernorm",
r"ffn_norm": r"post_attention_layernorm",
r"output": r"lm_head",
}
ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI = {
# Multimodal projector mapping
r"mlp1.0": r"multi_modal_projector.layer_norm",
r"mlp1.1": r"multi_modal_projector.linear_1",
r"mlp1.3": r"multi_modal_projector.linear_2",
}
chat_template = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n'}}"
"{% if message['content'] is string %}"
"{{ message['content'] }}"
"{% else %}"
"{% for content in message['content'] %}"
"{% if content['type'] == 'image' %}"
"{{ '<image>\n' }}"
"{% elif content['type'] == 'video' %}"
"{{ '<video>\n' }}"
"{% elif content['type'] == 'text' %}"
"{{ content['text'] }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"{{'<|im_end|>\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{'<|im_start|>assistant\n' }}"
"{% endif %}"
)
# fmt: on
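# For reference, rendering the template above on a single user turn containing one image and the text
# "Describe this image." with add_generation_prompt=True produces:
# <|im_start|>user
# <image>
# Describe this image.<|im_end|>
# <|im_start|>assistant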
CONTEXT_LENGTH = 8192
def convert_old_keys_to_new_keys(state_dict_keys: dict = None, path: str = None):
"""
This function should be applied only once, on the concatenated keys to efficiently rename using
the key mappings.
"""
output_dict = {}
if state_dict_keys is not None:
old_text_vision = "\n".join([key for key in state_dict_keys if key.startswith("vision_model")])
new_text = old_text_vision
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION.items():
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
new_text = old_text_language
if LM_TYPE_CORRESPONDENCE[path] == "llama":
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
new_text = re.sub(pattern, replacement, new_text)
output_dict.update(dict(zip(old_text_language.split("\n"), new_text.split("\n"))))
old_text_multi = "\n".join(
[
key
for key in state_dict_keys
if not (key.startswith("language_model") or key.startswith("vision_model"))
]
)
new_text = old_text_multi
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI.items():
new_text = re.sub(pattern, replacement, new_text)
output_dict.update(dict(zip(old_text_multi.split("\n"), new_text.split("\n"))))
return output_dict
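# As an illustration of the mapping above, a checkpoint key such as
# "vision_model.encoder.layers.0.attn.proj.weight" is renamed to
# "vision_tower.encoder.layer.0.attention.projection_layer.weight"; fused "attn.qkv" weights are only
# renamed to "attention.qkv" here and are split into separate q/k/v projections in `write_model` below.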
def load_original_state_dict(input_base_path):
model = AutoModel.from_pretrained(
input_base_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
use_flash_attn=False,
trust_remote_code=True,
).eval()
return model.state_dict()
def get_internvl_config(input_base_path):
base_config = AutoModel.from_pretrained(input_base_path, trust_remote_code=True).config
llm_config = base_config.llm_config.to_dict()
vision_config = base_config.vision_config.to_dict()
vision_config["use_absolute_position_embeddings"] = True
if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
image_token_id = 151667
language_config_class = Qwen2Config
else:
image_token_id = 92546
language_config_class = LlamaConfig
llm_config = {k: v for k, v in llm_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
# Force use_cache to True
llm_config["use_cache"] = True
# Force correct eos_token_id for InternVL3
if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
llm_config["eos_token_id"] = 151645
vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
if "attention_probs_dropout_prob" in vision_config:
attention_dropout = vision_config.pop("attention_probs_dropout_prob")
vision_config["attention_dropout"] = attention_dropout
vision_config["projection_dropout"] = attention_dropout
if "qk_normalization" in vision_config:
use_qk_norm = vision_config.pop("qk_normalization")
vision_config["use_qk_norm"] = use_qk_norm
if "qkv_bias" in vision_config:
attention_bias = vision_config.pop("qkv_bias")
vision_config["attention_bias"] = attention_bias
return InternVLConfig(
text_config=language_config_class(**llm_config),
vision_config=InternVLVisionConfig(**vision_config),
image_token_id=image_token_id,
)
def write_model(
model_path,
input_base_path,
push_to_hub=False,
hub_dir=None,
):
os.makedirs(model_path, exist_ok=True)
config = get_internvl_config(input_base_path)
config.architectures = ["InternVLForConditionalGeneration"]
config.save_pretrained(model_path)
if push_to_hub:
config.push_to_hub(hub_dir, use_temp_dir=True)
print("Model config saved successfully...")
# ------------------------------------------------------------
# Convert weights
# ------------------------------------------------------------
print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
state_dict_old = load_original_state_dict(input_base_path)
print("Converting model...")
all_keys = list(state_dict_old.keys())
new_keys = convert_old_keys_to_new_keys(all_keys, path=input_base_path)
lm_dim = config.text_config.hidden_size
dim = config.vision_config.hidden_size
state_dict = {}
for key in all_keys:
new_key = new_keys[key]
if "attn.qkv" in key:
new_key_query = new_key.replace("attention.qkv", "attention.q_proj")
state_dict[new_key_query] = state_dict_old[key][:dim]
new_key_key = new_key.replace("attention.qkv", "attention.k_proj")
state_dict[new_key_key] = state_dict_old[key][dim : 2 * dim]
new_key_value = new_key.replace("attention.qkv", "attention.v_proj")
state_dict[new_key_value] = state_dict_old[key][-dim:]
elif "attention.wqkv" in key:
num_key_value_groups = config.text_config.num_attention_heads // config.text_config.num_key_value_heads
head_dim = config.text_config.head_dim
wqkv_weights = state_dict_old[key]
qkv_vecs = rearrange(wqkv_weights, "(h gs d) z -> h gs d z", gs=2 + num_key_value_groups, d=head_dim)
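# The fused wqkv weight packs, for each key/value head, num_key_value_groups query heads followed by
# one key head and one value head; the slices below recover the separate q/k/v projection matrices.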
q_proj = qkv_vecs[:, :num_key_value_groups, ...].reshape(-1, lm_dim).contiguous()
k_proj = qkv_vecs[:, -2, ...].reshape(-1, lm_dim).contiguous()
v_proj = qkv_vecs[:, -1, ...].reshape(-1, lm_dim).contiguous()
new_key_query = new_key.replace("attention.wqkv", "self_attn.q_proj")
state_dict[new_key_query] = q_proj
new_key_key = new_key.replace("attention.wqkv", "self_attn.k_proj")
state_dict[new_key_key] = k_proj
new_key_value = new_key.replace("attention.wqkv", "self_attn.v_proj")
state_dict[new_key_value] = v_proj
else:
state_dict[new_key] = state_dict_old[key]
del state_dict_old
gc.collect()
print("Loading the checkpoint in a InternVLForConditionalGeneration model.")
model = InternVLForConditionalGeneration(config)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
model = model.to(torch.bfloat16)
print("model dtype:", model.dtype)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)
print("Saving the model.")
model.save_pretrained(model_path)
if push_to_hub:
model.push_to_hub(hub_dir, use_temp_dir=True)
image_processor = GotOcr2ImageProcessorFast.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = InternVLProcessor(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
processor.save_pretrained(model_path)
if push_to_hub:
processor.push_to_hub(hub_dir, use_temp_dir=True)
# generation config
if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
print("Saving generation config...")
# in the original model, eos_token is not the same in the text_config and the generation_config
# ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
generation_config = GenerationConfig(
eos_token_id=92542,
)
generation_config.save_pretrained(model_path)
if push_to_hub:
generation_config.push_to_hub(hub_dir, use_temp_dir=True)
# del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
model = InternVLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
print("Model reloaded successfully.")
del model
def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None, hub_dir: str = None):
if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen2.5-VL-7B-Instruct",
return_token_type_ids=False,
extra_special_tokens={
"start_image_token": "<img>",
"end_image_token": "</img>",
"context_image_token": "<IMG_CONTEXT>",
},
)
tokenizer.model_max_length = CONTEXT_LENGTH
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
"<img>",
"</img>",
"<IMG_CONTEXT>",
"<quad>",
"</quad>",
"<ref>",
"</ref>",
"<box>",
"</box>",
]
},
replace_additional_special_tokens=False,
)
else:
# Obtained with:
# tokenizer_llama_fast = LlamaTokenizerFast.from_pretrained(
# "OpenGVLab/InternVL2_5-2B-MPO", pad_token="</s>", legacy=False, from_slow=True
# )
# tokenizer_llama_fast._tokenizer.pre_tokenizer.prepend_scheme = "never"
# Then manually modifying `added_tokens_decoder` indices to match the original tokenizer
tokenizer = AutoTokenizer.from_pretrained(
"./intern_vl_hf_implem/tokenizer_internvl_llama_fast",
return_token_type_ids=False,
extra_special_tokens={
"start_image_token": "<img>",
"end_image_token": "</img>",
"context_image_token": "<IMG_CONTEXT>",
},
)
tokenizer.chat_template = chat_template
tokenizer.save_pretrained(save_dir)
if push_to_hub:
tokenizer.push_to_hub(hub_dir, use_temp_dir=True)
def write_image_processor(save_dir: str, push_to_hub: bool = False, hub_dir: str = None):
image_processor = GotOcr2ImageProcessorFast(
do_resize=True,
size={"height": 448, "width": 448},
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
do_convert_rgb=True,
)
image_processor.save_pretrained(save_dir)
if push_to_hub:
image_processor.push_to_hub(hub_dir, use_temp_dir=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
default="OpenGVLab/InternVL3-1B",
help="Location of original InternVL model",
)
parser.add_argument(
"--output_dir",
default="InternVL3-1B-hf",
help="Location to write HF model and processors",
)
parser.add_argument(
"--hub_dir",
default="OpenGVLab/InternVL3-1B-hf",
help="Repository ID on the Hub to push the converted model and processors to",
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
write_tokenizer(
save_dir=args.output_dir,
push_to_hub=args.push_to_hub,
path=args.input_dir,
hub_dir=args.hub_dir,
)
write_image_processor(
save_dir=args.output_dir,
push_to_hub=args.push_to_hub,
hub_dir=args.hub_dir,
)
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
push_to_hub=args.push_to_hub,
hub_dir=args.hub_dir,
)
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@ -0,0 +1,708 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
torch_int,
)
from ..clip.modeling_clip import CLIPMLP
from ..janus.modeling_janus import JanusVisionAttention
from ..llama.modeling_llama import LlamaRMSNorm
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "OpenGVLab/InternVL3-1B-hf"
_CONFIG_VISION_FOR_DOC = "InternVLVisionConfig"
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = key
value_states = value
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# No upcasting of the attention weights to float32 in this implementation
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class InternVLVisionRMSNorm(LlamaRMSNorm):
pass
class InternVLVisionAttention(JanusVisionAttention):
def __init__(self, config: InternVLVisionConfig):
super().__init__()
del self.num_key_value_groups
# Needed for flash attention
self.is_causal = False
qk_norm = config.use_qk_norm
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
class InternVLVisionPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = InternVLVisionConfig
base_model_prefix = "internvl_vision"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["InternVLVisionLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, InternVLVisionEmbeddings):
module.cls_token.data.zero_()
if module.mask_token is not None:
module.mask_token.data.zero_()
if module.position_embeddings is not None:
module.position_embeddings.data.zero_()
elif isinstance(module, InternVLVisionLayer):
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
@dataclass
class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
"""
Class for outputs of [`InternVLVisionModel`].
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
*config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
will be returned.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
class InternVLVisionPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.patch_shape = patch_shape
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.projection(pixel_values)
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, (patch_height, patch_width)
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class InternVLVisionEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
"""
def __init__(self, config: InternVLVisionConfig) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
if config.use_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
else:
self.mask_token = None
self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
self.patch_size = config.patch_size
self.image_size = (
config.image_size
if isinstance(config.image_size, collections.abc.Iterable)
else (config.image_size, config.image_size)
)
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
else:
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
# replace the masked visual tokens by mask_tokens
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1 - w) + mask_tokens * w
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if self.position_embeddings is not None:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
embeddings = self.dropout(embeddings)
return embeddings, (patch_height, patch_width)
class InternVLVisionMLP(CLIPMLP):
pass
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
class InternVLVisionLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: InternVLVisionConfig) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = InternVLVisionAttention(config)
self.mlp = InternVLVisionMLP(config)
# InternVL uses different layernorm implementations for different models
self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
init_values = config.layer_scale_init_value
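# lambda_1 / lambda_2 are learnable layer-scale parameters applied to the attention and MLP branches respectively.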
self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
attention_output, attention_weights = self.attention(
self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
output_attentions=output_attentions,
)
attention_output = self.lambda_1 * attention_output
# first residual connection
hidden_states = attention_output + hidden_states
# in InternVLVision, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.mlp(layer_output)
layer_output = self.dropout(layer_output)
if self.lambda_2 is not None:
layer_output = self.lambda_2 * layer_output
# second residual connection
layer_output = layer_output + hidden_states
return layer_output, attention_weights
class InternVLVisionEncoder(nn.Module):
def __init__(self, config: InternVLVisionConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, output_attentions
)
else:
layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_CONFIG_VISION_FOR_DOC = "InternVLVisionConfig"
INTERNVL_VISION_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`InternVLVisionConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
INTERNVL_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`InternVLVisionImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare InternVLVision Model transformer outputting raw hidden-states without any specific head on top.",
INTERNVL_VISION_START_DOCSTRING,
)
class InternVLVisionModel(InternVLVisionPreTrainedModel):
def __init__(self, config: InternVLVisionConfig) -> None:
super().__init__(config)
self.config = config
self.embeddings = InternVLVisionEmbeddings(config)
self.encoder = InternVLVisionEncoder(config)
self.layernorm = (
nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
@can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_VISION_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=InternVLVisionModelOutputWithPooling,
config_class=_CONFIG_VISION_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
encoder_outputs = self.encoder(
embedding_output,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
return InternVLVisionModelOutputWithPooling(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
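If the vision backbone is needed on its own, it can be instantiated directly; a sketch below, reusing the tiny configuration values from the model tests further down (random weights, so only the output shape is meaningful):
```python
import torch
from transformers import InternVLVisionConfig, InternVLVisionModel

config = InternVLVisionConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    image_size=64,
    patch_size=4,
    num_channels=3,
    hidden_act="quick_gelu",
    use_absolute_position_embeddings=True,
)
model = InternVLVisionModel(config).eval()
pixel_values = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 257, 32]) -> 16 * 16 patches + 1 CLS token
```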
_CONFIG_FOR_DOC = "InternVLConfig"
class InternVLPreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
INTERNVL_INPUTS_DOCSTRING = None
class InternVLMultiModalProjector(nn.Module):
def __init__(self, config: InternVLConfig):
super().__init__()
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
def forward(self, image_features):
hidden_states = self.layer_norm(image_features)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
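The projector's input width is the vision hidden size multiplied by `(1 / downsample_ratio) ** 2`, because the pixel shuffle below folds a 2x2 neighbourhood of vision tokens into the channel dimension. A quick arithmetic check, assuming a vision hidden size of 1024 and `downsample_ratio=0.5` (assumed values; the checkpoint config is authoritative):
```python
vision_hidden_size, downsample_ratio = 1024, 0.5               # assumed checkpoint values
projector_in_features = vision_hidden_size * int(1 / downsample_ratio) ** 2
print(projector_in_features)                                   # 4096 -> in_features of layer_norm and linear_1
```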
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
Args:
vision_features (`torch.Tensor`):
Input tensor of shape (batch_size, width, height, channels).
scale_factor (`float`, *optional*, defaults to `0.5`):
Factor by which to downsample. Default is 0.5, which halves the dimensions.
Returns:
vision_features (`torch.Tensor`):
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
"""
batch_size, width, height, channels = vision_features.size()
if height % scale_factor != 0 or width % scale_factor != 0:
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
# Reshape to allow downsampling
vision_features = vision_features.view(
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
)
# Permute dimensions to align downsampled axis correctly
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
# Reshape to achieve final downsampled dimensions
vision_features = vision_features.view(
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
)
# Swap height and width back for proper orientation
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
return vision_features
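For intuition, the same reshape/permute sequence on a dummy tensor (toy sizes, `scale_factor=0.5`): a 32x32 grid of 64-channel tokens becomes a 16x16 grid of 256-channel tokens, so the token count drops by 4x while the information moves into the channels.
```python
import torch

x = torch.randn(1, 32, 32, 64)                       # (batch_size, width, height, channels)
b, w, h, c = x.size()
x = x.view(b, w, int(h * 0.5), int(c / 0.5))         # fold half the height into the channels
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(b, int(h * 0.5), int(w * 0.5), int(c / 0.25))
x = x.permute(0, 2, 1, 3).contiguous()
print(x.shape)                                       # torch.Size([1, 16, 16, 256])
```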
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`int` or `List[int]`):
Layer index or list of layer indices to extract features from.
Returns:
vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
downsample_ratio = self.config.downsample_ratio
if vision_feature_layer == -1:
vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
else:
vision_features = self.vision_tower(pixel_values=pixel_values).hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
vision_features = vision_features[:, 1:, :]
# Calculate dimensions based on vision features
channels = vision_features.shape[1]
feature_size = int(channels**0.5)
batch_size = vision_features.shape[0]
# Reshape tensor to spatial dimensions
vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
# Apply downsampling using pixel shuffle
vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
# Reshape tensor to prepare for projection
vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
# Project features through multi-modal projector
vision_features = self.multi_modal_projector(vision_features)
return vision_features
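This is where the processor's `image_seq_length` of 256 comes from; the arithmetic, assuming `image_size=448`, `patch_size=14` and `downsample_ratio=0.5` (assumed checkpoint values):
```python
image_size, patch_size, downsample_ratio = 448, 14, 0.5   # assumed checkpoint values
patches_per_side = image_size // patch_size                # 32
num_patches = patches_per_side**2                          # 1024 patch tokens (the CLS token is dropped by the "default" strategy)
image_tokens = int(num_patches * downsample_ratio**2)
print(image_tokens)                                        # 256 -> matches InternVLProcessor.image_seq_length
```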
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
**lm_kwargs,
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> torch_device = "cuda"
>>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
>>> model = AutoModelForImageTextToText.from_pretrained(
... "OpenGVLab/InternVL3-1B-hf", torch_dtype=torch.bfloat16, device_map=torch_device
... )
>>> messages = [
... {
... "role": "user",
... "content": [
... {
... "type": "image",
... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
... },
... {
... "type": "image",
... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
... },
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
... ],
... },
... ]
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
>>> generate_ids = model.generate(**inputs, max_new_tokens=200)
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
The images depict the Statue of Liberty and the Golden Gate Bridge.
```"""
super().forward(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
__all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
"InternVLForConditionalGeneration",
]

@@ -0,0 +1,378 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.processing_utils import (
AllKwargsForChatTemplate,
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from ...image_processing_utils import BatchFeature
from ...image_utils import (
ImageInput,
VideoInput,
VideoMetadata,
concatenate_list,
make_batched_videos,
make_flat_list_of_images,
)
class InternVLImagesKwargs(ImagesKwargs, total=False):
crop_to_patches: Optional[bool]
min_patches: Optional[int]
max_patches: Optional[int]
class InternVLProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: InternVLImagesKwargs
_defaults = {
"text_kwargs": {
"padding_side": "left",
},
"images_kwargs": {
"crop_to_patches": True,
},
"videos_kwargs": {
"crop_to_patches": False,
},
}
class InternVLProcessor(ProcessorMixin):
r"""
Constructs an InternVL processor which wraps an [`AutoImageProcessor`] and
[`PreTrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
tokenizer functionalities. See the [`~InternVLProcessor.__call__`] and [`~InternVLProcessor.decode`] for more information.
Args:
image_processor ([`AutoImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
The tokenizer is a required input.
image_seq_length (`int`, *optional*, defaults to 256):
The number of image tokens to use per image patch. It should be set so that:
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.downsample_ratio**2)
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
fake_image_token (`str`, *optional*, defaults to `"<image>"`):
The token to use for the image placeholder in the text. This token will be replaced by the
appropriate image tokens when processing the text with images.
fake_video_token (`str`, *optional*, defaults to `"<video>"`):
The token to use for the video placeholder in the text. This token will be replaced by the
appropriate image tokens when processing the text with videos.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"image_seq_length",
"fake_image_token",
"fake_video_token",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
image_seq_length: int = 256,
chat_template=None,
fake_image_token="<image>",
fake_video_token="<video>",
**kwargs,
):
self.image_seq_length = image_seq_length
self.fake_image_token = fake_image_token
self.fake_video_token = fake_video_token
self.start_image_token = tokenizer.start_image_token
self.end_image_token = tokenizer.end_image_token
self.context_image_token = tokenizer.context_image_token
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
def _insert_media_placeholders(
self,
text: list[str],
image_pixel_values,
video_pixel_values,
image_num_patches: list[int],
video_num_patches: list[int],
image_num_patches_indices: np.ndarray,
video_num_patches_indices: np.ndarray,
video_patch_indices: np.ndarray,
):
"""
Processes interleaved text with <image> and <video> placeholders, replacing them with appropriate
image and video tokens while keeping track of the patches used.
"""
image_index = 0
video_index = 0
processed_text = []
image_video_patches = []
# Support interleaved image and video in prompts:
# Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
for prompt in text:
new_prompt = prompt
while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt:
if self.fake_image_token in new_prompt and (
self.fake_video_token not in new_prompt
or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token)
):
# Get the slice of patches corresponding to the current image
start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
end_index = image_num_patches_indices[image_index]
image_video_patches.append(image_pixel_values[start_index:end_index])
# Replace the corresponding image placeholder with the correct number of image tokens
new_prompt = new_prompt.replace(
self.fake_image_token,
f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}",
1,
)
image_index += 1
else:
# Get the slice of patches corresponding to the current video
# Here we need to account for both the multiple video frames and the potential multiple patches per frame
# As of now, InternVL only supports one patch per frame, but we keep the code flexible for future updates
current_patch_index = video_patch_indices[video_index - 1] if video_index > 0 else 0
end_patch_index = video_patch_indices[video_index]
start_index = video_num_patches_indices[current_patch_index] if video_index > 0 else 0
end_index = video_num_patches_indices[end_patch_index - 1]
image_video_patches.append(video_pixel_values[start_index:end_index])
# Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
video_prompt = "\n".join(
f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
for i in range(len(num_patches))
)
new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1)
video_index += 1
processed_text.append(new_prompt)
return processed_text, image_video_patches, image_index, video_index
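A standalone sketch of the expansion performed above, using toy values; the special-token strings are placeholders standing in for the tokenizer's `start_image_token`, `context_image_token` and `end_image_token`:
```python
image_seq_length = 4        # toy value; released checkpoints use 256
num_patches = 2             # patches produced for this image by crop_to_patches
prompt = "Describe <image> please."
expanded = prompt.replace(
    "<image>", "<img>" + "<IMG_CONTEXT>" * image_seq_length * num_patches + "</img>", 1
)
print(expanded.count("<IMG_CONTEXT>"))   # 8 -> image_seq_length * num_patches context tokens for this image
```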
def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio=None,
videos: Optional[VideoInput] = None,
**kwargs: Unpack[InternVLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text; `text`
is required, and a `ValueError` is raised if it is `None`. To prepare the vision inputs, this method forwards the
`images`/`videos` and `kwargs` arguments to GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images`
or `videos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if text is None:
raise ValueError("You have to specify text.")
output_kwargs = self._merge_kwargs(
InternVLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if not isinstance(text, (list, tuple)):
text = [text]
# Process images and videos separately, as videos don't support crop_to_patches
image_num_patches = []
video_num_patches = []
image_videos_inputs = {}
image_pixel_values = None
video_pixel_values = None
image_num_patches_indices = np.array([0])
video_patch_indices = np.array([0])
video_num_patches_indices = np.array([0])
if images is not None:
images = make_flat_list_of_images(images)
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
image_num_patches = image_inputs.pop("num_patches")
image_pixel_values = image_inputs.pop("pixel_values")
image_num_patches_indices = np.cumsum(image_num_patches)
if videos is not None:
videos = make_batched_videos(videos)
num_frames_per_video = [len(video) for video in videos]
video_patch_indices = np.cumsum(num_frames_per_video)
output_kwargs["images_kwargs"]["crop_to_patches"] = False
video_inputs = self.image_processor(images=videos, **output_kwargs["videos_kwargs"])
video_num_patches = video_inputs.pop("num_patches")
video_pixel_values = video_inputs.pop("pixel_values")
video_num_patches_indices = np.cumsum(video_num_patches)
if images is not None or videos is not None:
text, image_video_patches, image_index, video_index = self._insert_media_placeholders(
text,
image_pixel_values,
video_pixel_values,
image_num_patches,
video_num_patches,
image_num_patches_indices,
video_num_patches_indices,
video_patch_indices,
)
if images is not None and image_index != len(images):
raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
if videos is not None and video_index != len(videos):
raise ValueError("Number of video placeholders in the prompt does not match the number of videos.")
# Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_videos_inputs})
def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
):
"""
The function to generate indices of frames to sample from a video.
Args:
metadata (`VideoMetadata`):
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If None, all frames are sampled.
initial_shift (`bool`, `float` or `int`, defaults to `True`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
Returns:
`np.ndarray`: Array of frame indices to sample.
"""
if initial_shift is True:
initial_shift = metadata.total_num_frames / num_frames / 2 if num_frames is not None else 0
if num_frames is not None:
indices = np.arange(
initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames
).astype(int)
else:
indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
return indices
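For intuition, the same arithmetic on a standalone example: a 100-frame video, 8 requested frames and `initial_shift=True` yield evenly spaced, roughly centered indices.
```python
import numpy as np

total_num_frames, num_frames = 100, 8
initial_shift = total_num_frames / num_frames / 2                                   # 6.25: start half a stride in
indices = np.arange(initial_shift, total_num_frames, total_num_frames / num_frames).astype(int)
print(indices)                                                                      # [ 6 18 31 43 56 68 81 93]
```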
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(tokenizer_input_names) + list(image_processor_input_names)
# Add model-specific video sampling method when applying the template
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
chat_template: Optional[str] = None,
num_frames: int = 8,
initial_shift: Union[bool, float, int] = True,
video_load_backend="pyav",
**kwargs: Unpack[AllKwargsForChatTemplate],
):
"""
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
conversations to turn them into a single tokenizable string.
The input is expected to be in the following format, where each message content is a list consisting of text and
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Please describe this image in detail."},
],
},
]
Args:
conversation (`Union[List[Dict[str, str]], List[List[Dict[str, str]]]]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
chat template is used.
num_frames (`int`, *optional*, defaults to 8):
Number of frames to sample from a video when using the default `sample_indices_fn`.
initial_shift (`bool`, `float` or `int`, defaults to `True`):
The initial shift to apply when sampling frames using the default `sample_indices_fn`.
If `True`, the shift is set so that frames are sampled from the middle of the video.
"""
sample_indices_fn = kwargs.pop(
"sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
)
return super().apply_chat_template(
conversation,
chat_template,
video_load_backend=video_load_backend,
num_frames=num_frames,
sample_indices_fn=sample_indices_fn,
**kwargs,
)
__all__ = ["InternVLProcessor"]
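Because `apply_chat_template` pops a user-supplied `sample_indices_fn` from `kwargs` before falling back to the default above, a custom frame sampler can be plugged in directly. A hedged sketch (requires the `av` backend; the checkpoint name and message layout follow the integration tests below, and `first_k_frames` is a hypothetical sampler):
```python
import numpy as np
from transformers import AutoProcessor

def first_k_frames(metadata, k=4, **kwargs):
    # Hypothetical sampler: take the first k frames instead of uniform sampling.
    return np.arange(min(k, metadata.total_num_frames))

processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"},
            {"type": "text", "text": "What type of shot is the man performing?"},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    sample_indices_fn=first_k_frames,
)
```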

@@ -130,6 +130,7 @@ VLM_CLASS_NAMES = [
"gemma3",
"mistral3",
"chameleon",
"internvl",
"qwen2_5_omni",
]


@@ -0,0 +1,894 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch InternVL model."""
import unittest
from io import BytesIO
import requests
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
InternVLConfig,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
cleanup,
require_av,
require_bitsandbytes,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
InternVLForConditionalGeneration,
)
if is_vision_available():
from PIL import Image
class InternVLVisionText2TextModelTester:
def __init__(
self,
parent,
batch_size=3,
seq_length=7,
image_seq_length=64,
vision_feature_layer=-1,
ignore_index=-100,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_id=1,
num_channels=3,
image_size=64,
model_type="internvl",
is_training=True,
text_config={
"model_type": "qwen2",
"vocab_size": 99,
"hidden_size": 128,
"intermediate_size": 37,
"num_hidden_layers": 4,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"output_channels": 64,
"hidden_act": "silu",
"max_position_embeddings": 512,
"rope_theta": 10000,
"mlp_ratio": 4,
"tie_word_embeddings": True,
"bos_token_id": 0,
"eos_token_id": 0,
"pad_token_id": 0,
},
vision_config={
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 128,
"image_size": 64,
"patch_size": 4,
"num_channels": 3,
"hidden_act": "quick_gelu",
"use_absolute_position_embeddings": True,
},
):
self.parent = parent
self.ignore_index = ignore_index
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.image_token_id = image_token_id
self.model_type = model_type
self.text_config = text_config
self.vision_config = vision_config
self.batch_size = batch_size
self.vision_feature_layer = vision_feature_layer
self.is_training = is_training
self.image_seq_length = image_seq_length
self.num_channels = num_channels
self.image_size = image_size
self.seq_length = seq_length + image_seq_length
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
def get_config(self):
return InternVLConfig(
text_config=self.text_config,
vision_config=self.vision_config,
model_type=self.model_type,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
image_token_id=self.image_token_id,
image_seq_length=self.image_seq_length,
vision_feature_layer=self.vision_feature_layer,
)
def prepare_config_and_inputs(self):
config = self.get_config()
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
# input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
model = InternVLForConditionalGeneration(config=config)
model.to(torch_device)
model.half()
model.eval()
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
config.torch_dtype = torch.float16
model = InternVLForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
@require_torch
class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"image-text-to-text": InternVLForConditionalGeneration,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
def setUp(self):
self.model_tester = InternVLVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=InternVLConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip("Qwen2 flash attention does not support right padding")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
@slow
@require_torch_gpu
class InternVLQwen2IntegrationTest(unittest.TestCase):
def setUp(self):
self.small_model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
self.medium_model_checkpoint = "OpenGVLab/InternVL3-2B-hf"
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_qwen2_small_model_integration_generate(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
self.assertEqual(decoded_output, expected_output)
def test_qwen2_small_model_integration_forward(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
# Forward
with torch.inference_mode():
output = model(**inputs)
actual_logits = output.logits[0, -1, :5].cpu()
expected_logits = torch.tensor([11.9375, 14.8750, 14.0625, 10.7500, 6.9062], dtype=torch.bfloat16)
self.assertTrue(
torch.allclose(actual_logits, expected_logits, atol=0.1),
f"Actual logits: {actual_logits}"
f"\nExpected logits: {expected_logits}"
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
)
def test_qwen2_small_model_integration_generate_text_only(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
prompt = "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."
self.assertEqual(decoded_output, expected_output)
def test_qwen2_small_model_integration_generate_chat_template(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{"type": "text", "text": "Please describe the image explicitly."},
],
}
]
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
self.assertEqual(decoded_output, expected_output)
def test_qwen2_small_model_integration_batched_generate(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
torch_device, dtype=torch.bfloat16
)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
def test_qwen2_small_model_integration_batched_generate_multi_image(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(
BytesIO(
requests.get(
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
).content
)
)
image3 = Image.open(
BytesIO(
requests.get(
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
).content
)
)
inputs = processor(text=prompt, images=[[image1], [image2, image3]], padding=True, return_tensors="pt").to(
torch_device, dtype=torch.bfloat16
)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Angle' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
@require_av
@require_bitsandbytes
def test_qwen2_medium_model_integration_video(self):
processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = InternVLForConditionalGeneration.from_pretrained(
self.medium_model_checkpoint, quantization_config=quantization_config
)
# Prepare inputs
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
},
{"type": "text", "text": "What type of shot is the man performing?"},
],
}
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(torch_device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
expected_output = 'The man is performing a forehand shot.' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
@require_av
def test_qwen2_small_model_integration_interleaved_images_videos(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, torch_dtype=torch.bfloat16, device_map=torch_device
)
messages = [
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
},
{"type": "text", "text": "What are the differences between these two images?"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
},
{"type": "text", "text": "What type of shot is the man performing?"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://llava-vl.github.io/static/images/view.jpg",
},
{"type": "text", "text": "Write a haiku for this image"},
],
}
],
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
).to(torch_device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
decoded_output = processor.decode(output[0], skip_special_tokens=True)
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
expected_output = 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image**: This shows the Statue of Liberty on Liberty Island, with the' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check third output
decoded_output = processor.decode(output[2], skip_special_tokens=True)
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
@slow
@require_torch_gpu
class InternVLLlamaIntegrationTest(unittest.TestCase):
def setUp(self):
self.small_model_checkpoint = "OpenGVLab/InternVL2_5-2B-MPO-hf"
self.medium_model_checkpoint = "OpenGVLab/InternVL2_5-8B-MPO-hf"
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_llama_small_model_integration_generate(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "The image shows two cats sleeping on a pink couch. They are lying side by side, with their"
self.assertEqual(decoded_output, expected_output)
def test_llama_small_model_integration_forward(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
# Forward
with torch.inference_mode():
output = model(**inputs)
actual_logits = output.logits[0, -1, :5].cpu()
expected_logits = torch.tensor([-9.8750, -0.4258, 1.4844, -10.3125, -10.3125], dtype=torch.bfloat16)
# The original implementation and the transformers implementation do not match exactly, hence the higher tolerance.
# The difference is likely due to the different implementations of the attention mechanism (different order of operations)
# between the transformers Llama model and the original InternLM model.
# The difference has almost no effect on the output tokens, but it does affect the logits a lot more.
self.assertTrue(
torch.allclose(actual_logits, expected_logits, atol=1),
f"Actual logits: {actual_logits}"
f"\nExpected logits: {expected_logits}"
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
)
def test_llama_small_model_integration_generate_text_only(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
prompt = "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n"
inputs = processor(text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "Autumn leaves fall,\nNature's breath, a season's sigh,\nSilent woods awake."
self.assertEqual(decoded_output, expected_output)
def test_llama_small_model_integration_generate_chat_template(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{"type": "text", "text": "Please describe the image explicitly."},
],
}
]
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(torch_device, dtype=torch.bfloat16)
with torch.no_grad():
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
decoded_output = processor.decode(
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
expected_output = "The image shows two cats sleeping on a pink couch. They are lying side by side, with their"
self.assertEqual(decoded_output, expected_output)
def test_llama_small_model_integration_batched_generate(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
torch_device, dtype=torch.bfloat16
)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese gate in the background, adorned with red and gold colors and Chinese characters' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
def test_llama_small_model_integration_batched_generate_multi_image(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
)
# Prepare inputs
prompt = [
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(
BytesIO(
requests.get(
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
).content
)
)
image3 = Image.open(
BytesIO(
requests.get(
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
).content
)
)
inputs = processor(text=prompt, images=[[image1], [image2, image3]], padding=True, return_tensors="pt").to(
torch_device, dtype=torch.bfloat16
)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
# Check first output
decoded_output = processor.decode(output[0], skip_special_tokens=True)
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, still waters.' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After closely examining the images again, I can see that there are several differences' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
@require_av
@require_bitsandbytes
def test_llama_medium_model_integration_video(self):
processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = InternVLForConditionalGeneration.from_pretrained(
self.medium_model_checkpoint, quantization_config=quantization_config
)
# Prepare inputs
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
},
{"type": "text", "text": "What type of shot is the man performing?"},
],
}
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(torch_device, dtype=torch.float16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
expected_output = "The man is performing a forehand shot."
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
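
    # Mixed batch combining a multi-image conversation, a video conversation and a single-image
    # conversation, processed together through apply_chat_template and a single generate call.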
@require_av
def test_llama_small_model_integration_interleaved_images_videos(self):
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
self.small_model_checkpoint, torch_dtype=torch.bfloat16, device_map=torch_device
)
messages = [
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
},
{"type": "text", "text": "What are the difference between these two images?"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
},
{"type": "text", "text": "What type of shot is the man performing?"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://llava-vl.github.io/static/images/view.jpg",
},
{"type": "text", "text": "Write a haiku for this image"},
],
}
],
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
).to(torch_device, dtype=torch.bfloat16)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
decoded_output = processor.decode(output[0], skip_special_tokens=True)
        # Batching slightly alters the output, but the same happens in the original implementation; this appears to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
expected_output = 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check second output
decoded_output = processor.decode(output[1], skip_special_tokens=True)
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot. This is a common shot in tennis where the player swings the racket across their' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)
# Check third output
decoded_output = processor.decode(output[2], skip_special_tokens=True)
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, untouched dreams.' # fmt: skip
self.assertEqual(
decoded_output,
expected_output,
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
)

View File

@ -0,0 +1,327 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import shutil
import tempfile
import unittest

from huggingface_hub import hf_hub_download

from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
    import torch

if is_vision_available():
    from transformers import GotOcr2ImageProcessor

@require_vision
class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = InternVLProcessor
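    # Videos are flattened into image patches by the processor, so they are returned under
    # pixel_values rather than a dedicated video key.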
videos_input_name = "pixel_values"
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
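        # Use a deliberately small image processor configuration (20x20 images, at most 2 patches)
        # to keep the processor outputs tiny in these tests.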
image_processor = GotOcr2ImageProcessor(
do_resize=True,
size={"height": 20, "width": 20},
max_patches=2,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
do_center_crop=True,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
do_convert_rgb=True,
)
tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B-hf", padding_side="left")
processor_kwargs = cls.prepare_processor_dict()
processor = InternVLProcessor.from_pretrained(
"OpenGVLab/InternVL3-1B-hf",
image_processor=image_processor,
tokenizer=tokenizer,
**processor_kwargs,
)
processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.fake_image_token
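
    # A small image_seq_length keeps the number of image placeholder tokens expanded per image
    # patch low, which keeps the tokenized test inputs short.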
@staticmethod
def prepare_processor_dict():
return {"image_seq_length": 10}
    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def get_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
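
    # Check that batched chat-template processing yields the same input_ids and pixel_values as
    # processing each conversation separately (accounting for left padding and patch ordering).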
@require_av
@require_torch
def test_process_interleaved_images_videos(self):
processor = self.get_processor()
messages = [
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
},
{"type": "text", "text": "What are the differences between these two images?"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
},
{"type": "text", "text": "What type of shot is the man performing?"},
],
},
],
[
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://llava-vl.github.io/static/images/view.jpg",
},
{"type": "text", "text": "Write a haiku for this image"},
],
}
],
]
inputs_batched = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
)
        # Process the non-batched inputs to check that pixel_values and input_ids are reconstructed in the correct order when batched together
images_patches_index = 0
for i, message in enumerate(messages):
inputs = processor.apply_chat_template(
message,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
)
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
torch.testing.assert_close(
inputs["input_ids"][0], inputs_batched["input_ids"][i][-inputs["input_ids"].shape[1] :]
)
torch.testing.assert_close(
inputs["pixel_values"],
inputs_batched["pixel_values"][
images_patches_index : images_patches_index + inputs["pixel_values"].shape[0]
],
)
images_patches_index += inputs["pixel_values"].shape[0]

    # Override the common video chat_template tests, as InternVLProcessor returns flattened video features
@require_av
def test_apply_chat_template_video_special_processing(self):
"""
        Tests that models can apply their own custom preprocessing to conversations.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
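
        # Monkeypatch the message pre-processing hook with a dummy implementation so we can check
        # below that it is actually called before tokenization.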
def _process_messages_for_chat_template(
conversation,
batch_images,
batch_videos,
batch_video_metadata,
**chat_template_kwargs,
):
# Let us just always return a dummy prompt
new_msg = [
[
{
"role": "user",
"content": [
{"type": "video"}, # no need to use path, video is loaded already by this moment
{"type": "text", "text": "Dummy prompt for preprocess testing"},
],
},
]
]
return new_msg
processor._process_messages_for_chat_template = _process_messages_for_chat_template
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
        # Unlike the common tests, InternVLProcessor returns flattened video features and uses 8 frames by default
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
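
    # Frame sampling: the same video prompt is processed with num_frames, with video_fps, with both
    # (which should raise a ValueError), with the default sampling, and finally from a list of image URLs.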
def test_apply_chat_template_video_frame_sampling(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
messages = [
[
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
num_frames = 3
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
num_frames=num_frames,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
# Load with `video_fps` arg
video_fps = 1
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=None, # force to use default num_frames
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
# Load with `video_fps` and `num_frames` args, should raise an error
with self.assertRaises(ValueError):
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=num_frames,
)
        # Loading without any args should use the default loading method
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
        # Unlike the common tests, InternVLProcessor returns flattened video features and uses 8 frames by default
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
        # Load the video as a list of frames (i.e. images). NOTE: each frame should have the same size
        # because we assume they come from a single video
messages[0][0]["content"][0] = {
"type": "video",
"url": [
"https://www.ilankelman.org/stopsigns/australia.jpg",
"https://www.ilankelman.org/stopsigns/australia.jpg",
],
}
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 2)

View File

@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
"InternVLVisionModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model
]