mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add InternVL (2.5 MPO) (#35968)
* initial commit * add convert internvl * add first end-to-end working internvl * nit prompt and image proc * add working chat template * add conversion llama-based models * add tests * pass all tests * fix isort * fix modular after main merge * add video processing for internvl * add support for interlaced images and videos * Remove processing and config from modular, add more tests * add llama model tests * Modify processor for compatibility with refactored got ocr image processor * add comments in processor * Add docs and nits * change video processing to use custom sample_indices_fn * rebase and fix tests * add processor tests * Add changes Raushan review * Use the new attention interface for the vision model * nits * add support for custom video_load_backend * remove mention to InternVLTokenizer * refactor vision model to simplify logic * refactor processor for better readibility * fix copies * fix require av processor test * refactor internVL vision * Update processor and fix processing tests * fix docstring * update convert_weights for internvl3 * change image processor to fast by default * remove do_center_crop=True in convert_weights * force use_cache to True * push_to_hub before reloading * fix internVLVision for larger models * update convert weight for qk norm * fix convert_weights * fix eos_token_id in convert * update docs and integration tests * make modifs after review * fix wrong k_norm and reduce modular * change image_token_index to image_token_id * change checkpoint to OpenGVLab org * last nits * explicitely del self.num_key_value_groups * add extra special tokens
This commit is contained in:
parent
b0c6ff5e13
commit
a245011252
@ -953,6 +953,8 @@
|
||||
title: InstructBLIP
|
||||
- local: model_doc/instructblipvideo
|
||||
title: InstructBlipVideo
|
||||
- local: model_doc/internvl
|
||||
title: InternVL
|
||||
- local: model_doc/janus
|
||||
title: Janus
|
||||
- local: model_doc/kosmos-2
|
||||
|
349
docs/source/en/model_doc/internvl.md
Normal file
349
docs/source/en/model_doc/internvl.md
Normal file
@ -0,0 +1,349 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# InternVL
|
||||
|
||||
The InternVL3 family of Visual Language Models was introduced in [InternVL3: Exploring Advanced Training and Test-Time Recipes for Open-Source Multimodal Models](https://huggingface.co/papers/2504.10479).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We introduce InternVL3, a significant advancement in the InternVL series featuring a native multimodal pre-training paradigm. Rather than adapting a text-only large language model (LLM) into a multimodal large language model (MLLM) that supports visual inputs, InternVL3 jointly acquires multimodal and linguistic capabilities from both diverse multimodal data and pure-text corpora during a single pre-training stage. This unified training paradigm effectively addresses the complexities and alignment challenges commonly encountered in conventional post-hoc training pipelines for MLLMs. To further improve performance and scalability, InternVL3 incorporates variable visual position encoding (V2PE) to support extended multimodal contexts, employs advanced post-training techniques such as supervised fine-tuning (SFT) and mixed preference optimization (MPO), and adopts test-time scaling strategies alongside an optimized training infrastructure. Extensive empirical evaluations demonstrate that InternVL3 delivers superior performance across a wide range of multi-modal tasks. In particular, InternVL3-78B achieves a score of 72.2 on the MMMU benchmark, setting a new state-of-the-art among open-source MLLMs. Its capabilities remain highly competitive with leading proprietary models, including ChatGPT-4o, Claude 3.5 Sonnet, and Gemini 2.5 Pro, while also maintaining strong pure-language proficiency. In pursuit of open-science principles, we will publicly release both the training data and model weights to foster further research and development in next-generation MLLMs.*
|
||||
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/internvl_architecture.png" alt="drawing" width="600"/>
|
||||
|
||||
<small> Overview of InternVL3 models architecture, which is the same as InternVL2.5. Taken from the <a href="https://huggingface.co/OpenGVLab/InternVL3-1B">original checkpoint.</a> </small>
|
||||
|
||||
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/internvl_overview_performance.png" alt="drawing" width="600"/>
|
||||
|
||||
<small> Comparison of InternVL3 performance on OpenCompass against other SOTA VLLMs. Taken from the <a href="https://huggingface.co/OpenGVLab/InternVL3-1B">original checkpoint.</a> </small>
|
||||
|
||||
|
||||
|
||||
This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
|
||||
The original code can be found [here](https://github.com/OpenGVLab/InternVL).
|
||||
|
||||
## Usage example
|
||||
|
||||
### Inference with Pipeline
|
||||
|
||||
Here is how you can use the `image-text-to-text` pipeline to perform inference with the `InternVL3` models in just a few lines of code:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {
|
||||
... "type": "image",
|
||||
... "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
|
||||
... },
|
||||
... {"type": "text", "text": "Describe this image."},
|
||||
... ],
|
||||
... },
|
||||
... ]
|
||||
|
||||
>>> pipe = pipeline("image-text-to-text", model="OpenGVLab/InternVL3-1B-hf")
|
||||
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
|
||||
>>> outputs[0]["generated_text"]
|
||||
'The image showcases a vibrant scene of nature, featuring several flowers and a bee. \n\n1. **Foreground Flowers**: \n - The primary focus is on a large, pink cosmos flower with a prominent yellow center. The petals are soft and slightly r'
|
||||
```
|
||||
### Inference on a single image
|
||||
|
||||
This example demonstrates how to perform inference on a single image with the InternVL models using chat templates.
|
||||
|
||||
> [!NOTE]
|
||||
> Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||
... {"type": "text", "text": "Please describe the image explicitly."},
|
||||
... ],
|
||||
... }
|
||||
... ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=50)
|
||||
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
|
||||
>>> decoded_output
|
||||
'The image shows two cats lying on a pink blanket. The cat on the left is a tabby with a mix of brown, black, and white fur, and it appears to be sleeping with its head resting on the blanket. The cat on the'
|
||||
```
|
||||
|
||||
### Text-only generation
|
||||
This example shows how to generate text using the InternVL model without providing any image input.
|
||||
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "text", "text": "Write a haiku"},
|
||||
... ],
|
||||
... }
|
||||
... ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=50)
|
||||
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
|
||||
>>> print(decoded_output)
|
||||
"Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."
|
||||
```
|
||||
|
||||
### Batched image and text inputs
|
||||
InternVL models also support batched image and text inputs.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
... {"type": "text", "text": "Describe this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... ]
|
||||
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
["user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace.",
|
||||
'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of']
|
||||
```
|
||||
|
||||
### Batched multi-image input
|
||||
This implementation of the InternVL models supports batched text-images inputs with different number of images for each text.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
|
||||
... {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
|
||||
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
>>> ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
["user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace.",
|
||||
'user\n\n\nThese images depict two different landmarks. Can you identify them?\nassistant\nYes, these images depict the Statue of Liberty and the Golden Gate Bridge.']
|
||||
```
|
||||
|
||||
### Video input
|
||||
InternVL models can also handle video inputs. Here is an example of how to perform inference on a video input using chat templates.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
|
||||
>>> model_checkpoint = "OpenGVLab/InternVL3-8B-hf"
|
||||
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, quantization_config=quantization_config)
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {
|
||||
... "type": "video",
|
||||
... "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
... },
|
||||
... {"type": "text", "text": "What type of shot is the man performing?"},
|
||||
... ],
|
||||
... }
|
||||
>>> ]
|
||||
>>> inputs = processor.apply_chat_template(
|
||||
... messages,
|
||||
... return_tensors="pt",
|
||||
... add_generation_prompt=True,
|
||||
... tokenize=True,
|
||||
... return_dict=True,
|
||||
>>> ).to(model.device, dtype=torch.float16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
>>> decoded_output
|
||||
'The man is performing a forehand shot.'
|
||||
```
|
||||
|
||||
### Interleaved image and video inputs
|
||||
This example showcases how to handle a batch of chat conversations with interleaved image and video inputs using chat template.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
|
||||
... {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
|
||||
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "video", "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"},
|
||||
... {"type": "text", "text": "What type of shot is the man performing?"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
>>> ]
|
||||
>>> inputs = processor.apply_chat_template(
|
||||
... messages,
|
||||
... padding=True,
|
||||
... add_generation_prompt=True,
|
||||
... tokenize=True,
|
||||
... return_dict=True,
|
||||
... return_tensors="pt",
|
||||
>>> ).to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> outputs = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
['user\n\n\nThese images depict two different landmarks. Can you identify them?\nassistant\nThe images depict the Statue of Liberty and the Golden Gate Bridge.',
|
||||
'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot',
|
||||
"user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace."]
|
||||
```
|
||||
|
||||
## InternVLVisionConfig
|
||||
|
||||
[[autodoc]] InternVLVisionConfig
|
||||
|
||||
## InternVLConfig
|
||||
|
||||
[[autodoc]] InternVLConfig
|
||||
|
||||
## InternVLVisionModel
|
||||
|
||||
[[autodoc]] InternVLVisionModel
|
||||
- forward
|
||||
|
||||
## InternVLForConditionalGeneration
|
||||
|
||||
[[autodoc]] InternVLForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## InternVLProcessor
|
||||
|
||||
[[autodoc]] InternVLProcessor
|
@ -244,7 +244,7 @@ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_m
|
||||
|
||||
### Benchmarks
|
||||
|
||||
FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn't support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. The downside is batched generation is slower with padding tokens.
|
||||
FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn't support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. The downside is batched generation is slower with padding tokens.
|
||||
|
||||
<hfoptions id="padded">
|
||||
<hfoption id="short sequence length">
|
||||
|
@ -18,7 +18,7 @@ from collections.abc import Iterable
|
||||
from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
@ -77,9 +77,8 @@ if is_vision_available():
|
||||
pil_torch_interpolation_mapping = {}
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -162,6 +161,15 @@ def is_valid_list_of_images(images: list):
|
||||
return images and all(is_valid_image(image) for image in images)
|
||||
|
||||
|
||||
def concatenate_list(input_list):
|
||||
if isinstance(input_list[0], list):
|
||||
return [item for sublist in input_list for item in sublist]
|
||||
elif isinstance(input_list[0], np.ndarray):
|
||||
return np.concatenate(input_list, axis=0)
|
||||
elif isinstance(input_list[0], torch.Tensor):
|
||||
return torch.cat(input_list, dim=0)
|
||||
|
||||
|
||||
def valid_images(imgs):
|
||||
# If we have an list of images, make sure every image is valid
|
||||
if isinstance(imgs, (list, tuple)):
|
||||
|
@ -143,6 +143,7 @@ if TYPE_CHECKING:
|
||||
from .informer import *
|
||||
from .instructblip import *
|
||||
from .instructblipvideo import *
|
||||
from .internvl import *
|
||||
from .jamba import *
|
||||
from .janus import *
|
||||
from .jetmoe import *
|
||||
|
@ -162,6 +162,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("informer", "InformerConfig"),
|
||||
("instructblip", "InstructBlipConfig"),
|
||||
("instructblipvideo", "InstructBlipVideoConfig"),
|
||||
("internvl", "InternVLConfig"),
|
||||
("internvl_vision", "InternVLVisionConfig"),
|
||||
("jamba", "JambaConfig"),
|
||||
("janus", "JanusConfig"),
|
||||
("jetmoe", "JetMoeConfig"),
|
||||
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("informer", "Informer"),
|
||||
("instructblip", "InstructBLIP"),
|
||||
("instructblipvideo", "InstructBlipVideo"),
|
||||
("internvl", "InternVL"),
|
||||
("internvl_vision", "InternVLVision"),
|
||||
("jamba", "Jamba"),
|
||||
("janus", "Janus"),
|
||||
("jetmoe", "JetMoe"),
|
||||
@ -797,6 +801,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("chinese_clip_vision_model", "chinese_clip"),
|
||||
("rt_detr_resnet", "rt_detr"),
|
||||
("granitevision", "llava_next"),
|
||||
("internvl_vision", "internvl"),
|
||||
("qwen2_5_vl_text", "qwen2_5_vl"),
|
||||
("qwen2_vl_text", "qwen2_vl"),
|
||||
("sam_vision_model", "sam"),
|
||||
|
@ -151,6 +151,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("ijepa", "IJepaModel"),
|
||||
("imagegpt", "ImageGPTModel"),
|
||||
("informer", "InformerModel"),
|
||||
("internvl_vision", "InternVLVisionModel"),
|
||||
("jamba", "JambaModel"),
|
||||
("janus", "JanusModel"),
|
||||
("jetmoe", "JetMoeModel"),
|
||||
@ -862,6 +863,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("idefics2", "Idefics2ForConditionalGeneration"),
|
||||
("idefics3", "Idefics3ForConditionalGeneration"),
|
||||
("instructblip", "InstructBlipForConditionalGeneration"),
|
||||
("internvl", "InternVLForConditionalGeneration"),
|
||||
("janus", "JanusForConditionalGeneration"),
|
||||
("kosmos-2", "Kosmos2ForConditionalGeneration"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
|
@ -75,6 +75,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("idefics3", "Idefics3Processor"),
|
||||
("instructblip", "InstructBlipProcessor"),
|
||||
("instructblipvideo", "InstructBlipVideoProcessor"),
|
||||
("internvl", "InternVLProcessor"),
|
||||
("janus", "JanusProcessor"),
|
||||
("kosmos-2", "Kosmos2Processor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
|
@ -258,6 +258,7 @@ else:
|
||||
("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"jamba",
|
||||
(
|
||||
|
28
src/transformers/models/internvl/__init__.py
Normal file
28
src/transformers/models/internvl/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_internvl import *
|
||||
from .modeling_internvl import *
|
||||
from .processing_internvl import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
225
src/transformers/models/internvl/configuration_internvl.py
Normal file
225
src/transformers/models/internvl/configuration_internvl.py
Normal file
@ -0,0 +1,225 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
class InternVLVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InternVLVisionModel`]. It is used to instantiate an InternVLVisionModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
|
||||
a similar configuration to that of the InternVL3-1B.
|
||||
e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
use_qk_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to apply normalization to the queries and keys before the attention operation.
|
||||
intermediate_size (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability for attention weights.
|
||||
projection_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability for the projection layer.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The type of normalization to use in the encoder. Can be `"layer_norm"` or `"rms_norm"`.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
image_size (`int` or `list[int]`, *optional*, defaults to `[448, 448]`):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int` or `list[int]`, *optional*, defaults to `[14, 14]`):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
use_mask_token (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a mask token for masked image modeling.
|
||||
use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use BERT-style absolute position embeddings.
|
||||
layer_scale_init_value (`float`, *optional*, defaults to 0.1):
|
||||
Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
|
||||
use_mean_pooling (`bool`, *optional*, defaults to `True`):
|
||||
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
|
||||
CLS token, before applying the classification head.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import InternVLVisionConfig, InternVLVisionModel
|
||||
|
||||
>>> # Initializing a InternVLVisionModel OpenGVLab/InternVL3-1B-hf style configuration
|
||||
>>> configuration = InternVLVisionConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
|
||||
>>> model = InternVLVisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "internvl_vision"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
attention_bias=False,
|
||||
use_qk_norm=False,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_dropout=0.0,
|
||||
projection_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
norm_type="layer_norm",
|
||||
layer_norm_eps=1e-06,
|
||||
image_size=[448, 448],
|
||||
patch_size=[14, 14],
|
||||
num_channels=3,
|
||||
use_mask_token=False,
|
||||
use_absolute_position_embeddings=True,
|
||||
layer_scale_init_value=0.1,
|
||||
use_mean_pooling=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_bias = attention_bias
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_dropout = attention_dropout
|
||||
self.projection_dropout = projection_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.norm_type = norm_type
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
|
||||
patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.num_channels = num_channels
|
||||
self.use_mask_token = use_mask_token
|
||||
self.use_absolute_position_embeddings = use_absolute_position_embeddings
|
||||
self.layer_scale_init_value = layer_scale_init_value
|
||||
self.use_mean_pooling = use_mean_pooling
|
||||
|
||||
|
||||
class InternVLConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InternVLForConditionalGeneration`]. It is used to instantiate a
|
||||
InternVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of InternVL3-1B.
|
||||
e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `InternVisonConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
|
||||
The config object or dictionary of the text backbone.
|
||||
image_token_id (`int`, *optional*, defaults to 151667):
|
||||
The image token index to encode the image prompt.
|
||||
image_seq_length (`int`, *optional*, defaults to 256):
|
||||
Number of image tokens to use per image patch.
|
||||
downsample_ratio (`float`, *optional*, defaults to 0.5):
|
||||
Factor by which to downsample the image.
|
||||
projector_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the projector.
|
||||
vision_feature_layer (`int`, *optional*, defaults to -1):
|
||||
The index of the layer to use as the image features.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`.
|
||||
|
||||
```python
|
||||
>>> from transformers import InternVLForConditionalGeneration, InternVLConfig
|
||||
|
||||
>>> # Initializing a InternVL style configuration
|
||||
>>> configuration = InternVLConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
|
||||
>>> model = InternVLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "internvl"
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": InternVLVisionConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
image_token_id=151667,
|
||||
image_seq_length=256,
|
||||
downsample_ratio=0.5,
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_layer=-1,
|
||||
vision_feature_select_strategy="default",
|
||||
**kwargs,
|
||||
):
|
||||
self.image_token_id = image_token_id
|
||||
self.image_seq_length = image_seq_length
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = InternVLVisionConfig(**vision_config)
|
||||
elif isinstance(vision_config, InternVLVisionConfig):
|
||||
self.vision_config = vision_config
|
||||
elif vision_config is None:
|
||||
self.vision_config = InternVLVisionConfig()
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
|
||||
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
elif text_config is None:
|
||||
text_config = CONFIG_MAPPING["qwen2"]()
|
||||
|
||||
self.text_config = text_config
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["InternVLVisionConfig", "InternVLConfig"]
|
@ -0,0 +1,417 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
GotOcr2ImageProcessorFast,
|
||||
InternVLConfig,
|
||||
InternVLForConditionalGeneration,
|
||||
InternVLProcessor,
|
||||
InternVLVisionConfig,
|
||||
LlamaConfig,
|
||||
Qwen2Config,
|
||||
)
|
||||
|
||||
|
||||
LM_TYPE_CORRESPONDENCE = {
|
||||
"OpenGVLab/InternVL2_5-1B-MPO": "qwen2",
|
||||
"OpenGVLab/InternVL2_5-2B-MPO": "llama",
|
||||
"OpenGVLab/InternVL2_5-4B-MPO": "qwen2",
|
||||
"OpenGVLab/InternVL2_5-8B-MPO": "llama",
|
||||
"OpenGVLab/InternVL2_5-26B-MPO": "llama",
|
||||
"OpenGVLab/InternVL2_5-38B-MPO": "qwen2",
|
||||
"OpenGVLab/InternVL2_5-78B-MPO": "qwen2",
|
||||
"OpenGVLab/InternVL3-1B": "qwen2",
|
||||
"OpenGVLab/InternVL3-2B": "qwen2",
|
||||
"OpenGVLab/InternVL3-8B": "qwen2",
|
||||
"OpenGVLab/InternVL3-9B": "llama",
|
||||
"OpenGVLab/InternVL3-14B": "qwen2",
|
||||
"OpenGVLab/InternVL3-38B": "qwen2",
|
||||
"OpenGVLab/InternVL3-78B": "qwen2",
|
||||
}
|
||||
|
||||
UNNECESSARY_CONFIG_KEYS = [ "_name_or_path", "_attn_implementation_autoset", "auto_map", "use_bfloat16", "use_flash_attn", "bias", "laux_allreduce", "moe_coeff_ratio", "moe_intermediate_size", "moe_output_scale", "noisy_gate_policy", "shared_expert_intermediate_size", "use_residual", "use_moe", "use_rts", "use_weighted_residual", "moe_config", "num_experts", "num_routed_experts", "num_shared_experts", "capacity_factor", "eval_capacity_factor", "drop_path_rate"] # fmt: skip
|
||||
|
||||
# fmt: off
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION = {
|
||||
# Vision encoder mapping
|
||||
r"vision_model": r"vision_tower",
|
||||
r"layers": r"layer",
|
||||
r"class_embedding": r"cls_token",
|
||||
r"position_embedding": r"position_embeddings",
|
||||
r"patch_embedding": r"patch_embeddings.projection",
|
||||
r"ls(\d+)": r"lambda_\1",
|
||||
r"attn.proj": r"attention.projection_layer",
|
||||
r"attn.dropout": r"attention.projection_dropout",
|
||||
r"attn": r"attention",
|
||||
r"norm1": r"layernorm_before",
|
||||
r"norm2": r"layernorm_after",
|
||||
|
||||
}
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA = {
|
||||
# Vision encoder mapping
|
||||
r"tok_embeddings": r"embed_tokens",
|
||||
r"attention.wo": r"self_attn.o_proj",
|
||||
r"feed_forward.w1": r"mlp.gate_proj",
|
||||
r"feed_forward.w2": r"mlp.down_proj",
|
||||
r"feed_forward.w3": r"mlp.up_proj",
|
||||
r"attention_norm": r"input_layernorm",
|
||||
r"ffn_norm": r"post_attention_layernorm",
|
||||
r"output": r"lm_head",
|
||||
}
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI = {
|
||||
# Vision encoder mapping
|
||||
r"mlp1.0": r"multi_modal_projector.layer_norm",
|
||||
r"mlp1.1": r"multi_modal_projector.linear_1",
|
||||
r"mlp1.3": r"multi_modal_projector.linear_2",
|
||||
}
|
||||
|
||||
|
||||
chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n'}}"
|
||||
"{% if message['content'] is string %}"
|
||||
"{{ message['content'] }}"
|
||||
"{% else %}"
|
||||
"{% for content in message['content'] %}"
|
||||
"{% if content['type'] == 'image' %}"
|
||||
"{{ '<image>\n' }}"
|
||||
"{% elif content['type'] == 'video' %}"
|
||||
"{{ '<video>\n' }}"
|
||||
"{% elif content['type'] == 'text' %}"
|
||||
"{{ content['text'] }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% endif %}"
|
||||
"{{'<|im_end|>\n'}}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{'<|im_start|>assistant\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
CONTEXT_LENGTH = 8192
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: dict = None, path: str = None):
|
||||
"""
|
||||
This function should be applied only once, on the concatenated keys to efficiently rename using
|
||||
the key mappings.
|
||||
"""
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text_vision = "\n".join([key for key in state_dict_keys if key.startswith("vision_model")])
|
||||
new_text = old_text_vision
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION.items():
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
|
||||
old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
|
||||
new_text = old_text_language
|
||||
if LM_TYPE_CORRESPONDENCE[path] == "llama":
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict.update(dict(zip(old_text_language.split("\n"), new_text.split("\n"))))
|
||||
old_text_multi = "\n".join(
|
||||
[
|
||||
key
|
||||
for key in state_dict_keys
|
||||
if not (key.startswith("language_model") or key.startswith("vision_model"))
|
||||
]
|
||||
)
|
||||
new_text = old_text_multi
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI.items():
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict.update(dict(zip(old_text_multi.split("\n"), new_text.split("\n"))))
|
||||
|
||||
return output_dict
|
||||
|
||||
|
||||
def load_original_state_dict(input_base_path):
|
||||
model = AutoModel.from_pretrained(
|
||||
input_base_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
use_flash_attn=False,
|
||||
trust_remote_code=True,
|
||||
).eval()
|
||||
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
def get_internvl_config(input_base_path):
|
||||
base_config = AutoModel.from_pretrained(input_base_path, trust_remote_code=True).config
|
||||
llm_config = base_config.llm_config.to_dict()
|
||||
vision_config = base_config.vision_config.to_dict()
|
||||
vision_config["use_absolute_position_embeddings"] = True
|
||||
if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
|
||||
image_token_id = 151667
|
||||
language_config_class = Qwen2Config
|
||||
else:
|
||||
image_token_id = 92546
|
||||
language_config_class = LlamaConfig
|
||||
|
||||
llm_config = {k: v for k, v in llm_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
|
||||
# Force use_cache to True
|
||||
llm_config["use_cache"] = True
|
||||
# Force correct eos_token_id for InternVL3
|
||||
if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
|
||||
llm_config["eos_token_id"] = 151645
|
||||
|
||||
vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
|
||||
if "attention_probs_dropout_prob" in vision_config:
|
||||
attention_dropout = vision_config.pop("attention_probs_dropout_prob")
|
||||
vision_config["attention_dropout"] = attention_dropout
|
||||
vision_config["projection_dropout"] = attention_dropout
|
||||
if "qk_normalization" in vision_config:
|
||||
use_qk_norm = vision_config.pop("qk_normalization")
|
||||
vision_config["use_qk_norm"] = use_qk_norm
|
||||
if "qkv_bias" in vision_config:
|
||||
attention_bias = vision_config.pop("qkv_bias")
|
||||
vision_config["attention_bias"] = attention_bias
|
||||
|
||||
return InternVLConfig(
|
||||
text_config=language_config_class(**llm_config),
|
||||
vision_config=InternVLVisionConfig(**vision_config),
|
||||
image_token_id=image_token_id,
|
||||
)
|
||||
|
||||
|
||||
def write_model(
|
||||
model_path,
|
||||
input_base_path,
|
||||
push_to_hub=False,
|
||||
hub_dir=None,
|
||||
):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
config = get_internvl_config(input_base_path)
|
||||
config.architectures = ["InternVLForConditionalGeneration"]
|
||||
config.save_pretrained(model_path)
|
||||
if push_to_hub:
|
||||
config.push_to_hub(hub_dir, use_temp_dir=True)
|
||||
print("Model config saved successfully...")
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Convert weights
|
||||
# ------------------------------------------------------------
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
|
||||
state_dict_old = load_original_state_dict(input_base_path)
|
||||
print("Converting model...")
|
||||
all_keys = list(state_dict_old.keys())
|
||||
new_keys = convert_old_keys_to_new_keys(all_keys, path=input_base_path)
|
||||
lm_dim = config.text_config.hidden_size
|
||||
dim = config.vision_config.hidden_size
|
||||
state_dict = {}
|
||||
for key in all_keys:
|
||||
new_key = new_keys[key]
|
||||
if "attn.qkv" in key:
|
||||
new_key_query = new_key.replace("attention.qkv", "attention.q_proj")
|
||||
state_dict[new_key_query] = state_dict_old[key][:dim]
|
||||
|
||||
new_key_key = new_key.replace("attention.qkv", "attention.k_proj")
|
||||
state_dict[new_key_key] = state_dict_old[key][dim : 2 * dim]
|
||||
|
||||
new_key_value = new_key.replace("attention.qkv", "attention.v_proj")
|
||||
state_dict[new_key_value] = state_dict_old[key][-dim:]
|
||||
elif "attention.wqkv" in key:
|
||||
num_key_value_groups = config.text_config.num_attention_heads // config.text_config.num_key_value_heads
|
||||
head_dim = config.text_config.head_dim
|
||||
wqkv_weights = state_dict_old[key]
|
||||
|
||||
qkv_vecs = rearrange(wqkv_weights, "(h gs d) z -> h gs d z", gs=2 + num_key_value_groups, d=head_dim)
|
||||
q_proj = qkv_vecs[:, :num_key_value_groups, ...].reshape(-1, lm_dim).contiguous()
|
||||
k_proj = qkv_vecs[:, -2, ...].reshape(-1, lm_dim).contiguous()
|
||||
v_proj = qkv_vecs[:, -1, ...].reshape(-1, lm_dim).contiguous()
|
||||
|
||||
new_key_query = new_key.replace("attention.wqkv", "self_attn.q_proj")
|
||||
state_dict[new_key_query] = q_proj
|
||||
|
||||
new_key_key = new_key.replace("attention.wqkv", "self_attn.k_proj")
|
||||
state_dict[new_key_key] = k_proj
|
||||
|
||||
new_key_value = new_key.replace("attention.wqkv", "self_attn.v_proj")
|
||||
state_dict[new_key_value] = v_proj
|
||||
else:
|
||||
state_dict[new_key] = state_dict_old[key]
|
||||
|
||||
del state_dict_old
|
||||
gc.collect()
|
||||
|
||||
print("Loading the checkpoint in a InternVLForConditionalGeneration model.")
|
||||
model = InternVLForConditionalGeneration(config)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
model = model.to(torch.bfloat16)
|
||||
print("model dtype:", model.dtype)
|
||||
print("Missing keys:", missing_keys)
|
||||
print("Unexpected keys:", unexpected_keys)
|
||||
|
||||
print("Saving the model.")
|
||||
model.save_pretrained(model_path)
|
||||
if push_to_hub:
|
||||
model.push_to_hub(hub_dir, use_temp_dir=True)
|
||||
|
||||
image_processor = GotOcr2ImageProcessorFast.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
processor = InternVLProcessor(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
|
||||
processor.save_pretrained(model_path)
|
||||
if push_to_hub:
|
||||
processor.push_to_hub(hub_dir, use_temp_dir=True)
|
||||
|
||||
# generation config
|
||||
if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
|
||||
print("Saving generation config...")
|
||||
# in the original model, eos_token is not the same in the text_config and the generation_config
|
||||
# ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
|
||||
generation_config = GenerationConfig(
|
||||
eos_token_id=92542,
|
||||
)
|
||||
generation_config.save_pretrained(model_path)
|
||||
if push_to_hub:
|
||||
generation_config.push_to_hub(hub_dir, use_temp_dir=True)
|
||||
|
||||
# del state_dict, model
|
||||
|
||||
# # Safety check: reload the converted model
|
||||
gc.collect()
|
||||
print("Reloading the model to check if it's saved correctly.")
|
||||
model = InternVLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
print("Model reloaded successfully.")
|
||||
del model
|
||||
|
||||
|
||||
def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None, hub_dir: str = None):
|
||||
if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
return_token_type_ids=False,
|
||||
extra_special_tokens={
|
||||
"start_image_token": "<img>",
|
||||
"end_image_token": "</img>",
|
||||
"context_image_token": "<IMG_CONTEXT>",
|
||||
},
|
||||
)
|
||||
tokenizer.model_max_length = CONTEXT_LENGTH
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<img>",
|
||||
"</img>",
|
||||
"<IMG_CONTEXT>",
|
||||
"<quad>",
|
||||
"</quad>",
|
||||
"<ref>",
|
||||
"</ref>",
|
||||
"<box>",
|
||||
"</box>",
|
||||
]
|
||||
},
|
||||
replace_additional_special_tokens=False,
|
||||
)
|
||||
else:
|
||||
# Obtained with:
|
||||
# tokenizer_llama_fast = LlamaTokenizerFast.from_pretrained(
|
||||
# "OpenGVLab/InternVL2_5-2B-MPO", pad_token="</s>", legacy=False, from_slow=True
|
||||
# )
|
||||
# tokenizer_llama_fast._tokenizer.pre_tokenizer.prepend_scheme = "never"
|
||||
# Then manually modifying `added_tokens_decoder` indices to match the original tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"./intern_vl_hf_implem/tokenizer_internvl_llama_fast",
|
||||
return_token_type_ids=False,
|
||||
extra_special_tokens={
|
||||
"start_image_token": "<img>",
|
||||
"end_image_token": "</img>",
|
||||
"context_image_token": "<IMG_CONTEXT>",
|
||||
},
|
||||
)
|
||||
|
||||
tokenizer.chat_template = chat_template
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
if push_to_hub:
|
||||
tokenizer.push_to_hub(hub_dir, use_temp_dir=True)
|
||||
|
||||
|
||||
def write_image_processor(save_dir: str, push_to_hub: bool = False, hub_dir: str = None):
|
||||
image_processor = GotOcr2ImageProcessorFast(
|
||||
do_resize=True,
|
||||
size={"height": 448, "width": 448},
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
image_mean=[0.485, 0.456, 0.406],
|
||||
image_std=[0.229, 0.224, 0.225],
|
||||
do_convert_rgb=True,
|
||||
)
|
||||
|
||||
image_processor.save_pretrained(save_dir)
|
||||
if push_to_hub:
|
||||
image_processor.push_to_hub(hub_dir, use_temp_dir=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
default="OpenGVLab/InternVL3-1B",
|
||||
help="Location of original InternVL model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default="InternVL3-1B-hf",
|
||||
help="Location to write HF model and processors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_dir",
|
||||
default="OpenGVLab/InternVL3-1B-hf",
|
||||
help="Location to write HF model and processors",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_tokenizer(
|
||||
save_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
path=args.input_dir,
|
||||
hub_dir=args.hub_dir,
|
||||
)
|
||||
|
||||
write_image_processor(
|
||||
save_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
hub_dir=args.hub_dir,
|
||||
)
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=args.input_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
hub_dir=args.hub_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1094
src/transformers/models/internvl/modeling_internvl.py
Normal file
1094
src/transformers/models/internvl/modeling_internvl.py
Normal file
File diff suppressed because it is too large
Load Diff
708
src/transformers/models/internvl/modular_internvl.py
Normal file
708
src/transformers/models/internvl/modular_internvl.py
Normal file
@ -0,0 +1,708 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import collections.abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ..clip.modeling_clip import CLIPMLP
|
||||
from ..janus.modeling_janus import JanusVisionAttention
|
||||
from ..llama.modeling_llama import LlamaRMSNorm
|
||||
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
|
||||
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "OpenGVLab/InternVL3-1B-hf"
|
||||
|
||||
_CONFIG_VISION_FOR_DOC = "InternVLVisionConfig"
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = key
|
||||
value_states = value
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# No upcasting of the attention weights to float32 in this implementation
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class InternVLVisionRMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class InternVLVisionAttention(JanusVisionAttention):
|
||||
def __init__(self, config: InternVLVisionConfig):
|
||||
super().__init__()
|
||||
del self.num_key_value_groups
|
||||
|
||||
# Needed for flash attention
|
||||
self.is_causal = False
|
||||
qk_norm = config.use_qk_norm
|
||||
|
||||
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
||||
|
||||
|
||||
class InternVLVisionPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = InternVLVisionConfig
|
||||
base_model_prefix = "internvl_vision"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["InternVLVisionLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, InternVLVisionEmbeddings):
|
||||
module.cls_token.data.zero_()
|
||||
if module.mask_token is not None:
|
||||
module.mask_token.data.zero_()
|
||||
if module.position_embeddings is not None:
|
||||
module.position_embeddings.data.zero_()
|
||||
elif isinstance(module, InternVLVisionLayer):
|
||||
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
|
||||
"""
|
||||
Class for outputs of [`InternVLVisionModel`].
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
||||
Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
|
||||
*config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
|
||||
will be returned.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
|
||||
class InternVLVisionPatchEmbeddings(nn.Module):
|
||||
"""
|
||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||
Transformer.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
image_size, patch_size = config.image_size, config.patch_size
|
||||
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.num_patches = num_patches
|
||||
self.patch_shape = patch_shape
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
|
||||
embeddings = self.projection(pixel_values)
|
||||
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
|
||||
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
# Based on timm implementation, which can be found here:
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
class InternVLVisionEmbeddings(nn.Module):
|
||||
"""
|
||||
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
if config.use_mask_token:
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
else:
|
||||
self.mask_token = None
|
||||
self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
|
||||
self.patch_size = config.patch_size
|
||||
self.image_size = (
|
||||
config.image_size
|
||||
if isinstance(config.image_size, collections.abc.Iterable)
|
||||
else (config.image_size, config.image_size)
|
||||
)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
if config.use_absolute_position_embeddings:
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
else:
|
||||
self.position_embeddings = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
||||
# replace the masked visual tokens by mask_tokens
|
||||
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||
embeddings = embeddings * (1 - w) + mask_tokens * w
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
class InternVLVisionMLP(CLIPMLP):
|
||||
pass
|
||||
|
||||
|
||||
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
|
||||
|
||||
|
||||
class InternVLVisionLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = InternVLVisionAttention(config)
|
||||
self.mlp = InternVLVisionMLP(config)
|
||||
# InternVL uses different layernorm implementations for different models
|
||||
self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
init_values = config.layer_scale_init_value
|
||||
self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
||||
self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
attention_output, attention_weights = self.attention(
|
||||
self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
attention_output = self.lambda_1 * attention_output
|
||||
|
||||
# first residual connection
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
# in InternVLVision, layernorm is also applied after self-attention
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
|
||||
layer_output = self.mlp(layer_output)
|
||||
layer_output = self.dropout(layer_output)
|
||||
|
||||
if self.lambda_2 is not None:
|
||||
layer_output = self.lambda_2 * layer_output
|
||||
|
||||
# second residual connection
|
||||
layer_output = layer_output + hidden_states
|
||||
|
||||
return layer_output, attention_weights
|
||||
|
||||
|
||||
class InternVLVisionEncoder(nn.Module):
|
||||
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__, hidden_states, output_attentions
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
||||
|
||||
_CONFIG_VISION_FOR_DOC = "InternVLVisionConfig"
|
||||
|
||||
|
||||
INTERNVL_VISION_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`InternVLVisionConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
INTERNVL_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
||||
[`InternVLVisionImageProcessor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare InternVLVision Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
INTERNVL_VISION_START_DOCSTRING,
|
||||
)
|
||||
class InternVLVisionModel(InternVLVisionPreTrainedModel):
|
||||
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = InternVLVisionEmbeddings(config)
|
||||
self.encoder = InternVLVisionEncoder(config)
|
||||
|
||||
self.layernorm = (
|
||||
nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INTERNVL_VISION_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=InternVLVisionModelOutputWithPooling,
|
||||
config_class=_CONFIG_VISION_FOR_DOC,
|
||||
modality="vision",
|
||||
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
|
||||
r"""
|
||||
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
|
||||
return InternVLVisionModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "InternVLConfig"
|
||||
|
||||
|
||||
class InternVLPreTrainedModel(LlavaPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
INTERNVL_INPUTS_DOCSTRING = None
|
||||
|
||||
|
||||
class InternVLMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: InternVLConfig):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.layer_norm(image_features)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
|
||||
"""Perform pixel shuffle downsampling on vision features.
|
||||
|
||||
Args:
|
||||
vision_features (`torch.Tensor`):
|
||||
Input tensor of shape (batch_size, width, height, channels).
|
||||
scale_factor (`float`, *optional*, defaults to `0.5`):
|
||||
Factor by which to downsample. Default is 0.5, which halves the dimensions.
|
||||
|
||||
Returns:
|
||||
vision_features (`torch.Tensor`):
|
||||
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
|
||||
"""
|
||||
batch_size, width, height, channels = vision_features.size()
|
||||
|
||||
if height % scale_factor != 0 or width % scale_factor != 0:
|
||||
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
|
||||
|
||||
# Reshape to allow downsampling
|
||||
vision_features = vision_features.view(
|
||||
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
|
||||
)
|
||||
# Permute dimensions to align downsampled axis correctly
|
||||
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# Reshape to achieve final downsampled dimensions
|
||||
vision_features = vision_features.view(
|
||||
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
|
||||
)
|
||||
|
||||
# Swap height and width back for proper orientation
|
||||
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
return vision_features
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`int` or `List[int]`):
|
||||
Layer index or list of layer indices to extract features from.
|
||||
Returns:
|
||||
vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
||||
"""
|
||||
downsample_ratio = self.config.downsample_ratio
|
||||
if vision_feature_layer == -1:
|
||||
vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||
else:
|
||||
vision_features = self.vision_model(pixel_values=pixel_values).hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
vision_features = vision_features[:, 1:, :]
|
||||
|
||||
# Calculate dimensions based on vision features
|
||||
channels = vision_features.shape[1]
|
||||
feature_size = int(channels**0.5)
|
||||
batch_size = vision_features.shape[0]
|
||||
|
||||
# Reshape tensor to spatial dimensions
|
||||
vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
|
||||
|
||||
# Apply downsampling using pixel shuffle
|
||||
vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
|
||||
|
||||
# Reshape tensor to prepare for projection
|
||||
vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
|
||||
|
||||
# Project features through multi-modal projector
|
||||
vision_features = self.multi_modal_projector(vision_features)
|
||||
|
||||
return vision_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(
|
||||
... "OpenGVLab/InternVL3-1B-hf", torch_dtype=torch.bfloat16, device_map=torch_device
|
||||
... )
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {
|
||||
... "type": "image",
|
||||
... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
... },
|
||||
... {
|
||||
... "type": "image",
|
||||
... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
... },
|
||||
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||
... ],
|
||||
... },
|
||||
... ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=200)
|
||||
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
|
||||
The images depict the Statue of Liberty and the Golden Gate Bridge.
|
||||
```"""
|
||||
super().forward(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InternVLVisionPreTrainedModel",
|
||||
"InternVLVisionModel",
|
||||
"InternVLPreTrainedModel",
|
||||
"InternVLForConditionalGeneration",
|
||||
]
|
378
src/transformers/models/internvl/processing_internvl.py
Normal file
378
src/transformers/models/internvl/processing_internvl.py
Normal file
@ -0,0 +1,378 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.processing_utils import (
|
||||
AllKwargsForChatTemplate,
|
||||
ImagesKwargs,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
Unpack,
|
||||
)
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import (
|
||||
ImageInput,
|
||||
VideoInput,
|
||||
VideoMetadata,
|
||||
concatenate_list,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
)
|
||||
|
||||
|
||||
class InternVLImagesKwargs(ImagesKwargs, total=False):
|
||||
crop_to_patches: Optional[bool]
|
||||
min_patches: Optional[int]
|
||||
max_patches: Optional[int]
|
||||
|
||||
|
||||
class InternVLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: InternVLImagesKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding_side": "left",
|
||||
},
|
||||
"images_kwargs": {
|
||||
"crop_to_patches": True,
|
||||
},
|
||||
"videos_kwargs": {
|
||||
"crop_to_patches": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class InternVLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a InternVL processor which wraps a [`AutoImageProcessor`] and
|
||||
[`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
|
||||
tokenizer functionalities. See the [`~InternVLProcessor.__call__`] and [`~InternVLProcessor.decode`] for more information.
|
||||
Args:
|
||||
image_processor ([`AutoImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
image_seq_length (`int`, *optional*, defaults to 256):
|
||||
The number of image token to use per image patch. it should be set so that:
|
||||
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
fake_image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
The token to use for the image placeholder in the text. This token will be replaced by the
|
||||
appropriate image tokens when processing the text with images.
|
||||
fake_video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
The token to use for the video placeholder in the text. This token will be replaced by the
|
||||
appropriate image tokens when processing the text with videos.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"image_seq_length",
|
||||
"fake_image_token",
|
||||
"fake_video_token",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
image_seq_length: int = 256,
|
||||
chat_template=None,
|
||||
fake_image_token="<image>",
|
||||
fake_video_token="<video>",
|
||||
**kwargs,
|
||||
):
|
||||
self.image_seq_length = image_seq_length
|
||||
self.fake_image_token = fake_image_token
|
||||
self.fake_video_token = fake_video_token
|
||||
self.start_image_token = tokenizer.start_image_token
|
||||
self.end_image_token = tokenizer.end_image_token
|
||||
self.context_image_token = tokenizer.context_image_token
|
||||
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||
|
||||
def _insert_media_placeholders(
|
||||
self,
|
||||
text: list[str],
|
||||
image_pixel_values,
|
||||
video_pixel_values,
|
||||
image_num_patches: list[int],
|
||||
video_num_patches: list[int],
|
||||
image_num_patches_indices: np.ndarray,
|
||||
video_num_patches_indices: np.ndarray,
|
||||
video_patch_indices: np.ndarray,
|
||||
):
|
||||
"""
|
||||
Processes interleaved text with <image> and <video> placeholders, replacing them with appropriate
|
||||
image and video tokens while keeping track of the patches used.
|
||||
"""
|
||||
image_index = 0
|
||||
video_index = 0
|
||||
processed_text = []
|
||||
image_video_patches = []
|
||||
# Support interleaved image and video in prompts:
|
||||
# Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
|
||||
for prompt in text:
|
||||
new_prompt = prompt
|
||||
while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt:
|
||||
if self.fake_image_token in new_prompt and (
|
||||
self.fake_video_token not in new_prompt
|
||||
or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token)
|
||||
):
|
||||
# Get the slice of patches corresponding to the current image
|
||||
start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
|
||||
end_index = image_num_patches_indices[image_index]
|
||||
image_video_patches.append(image_pixel_values[start_index:end_index])
|
||||
# Replace the corresponding image placeholder with the correct number of image tokens
|
||||
new_prompt = new_prompt.replace(
|
||||
self.fake_image_token,
|
||||
f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}",
|
||||
1,
|
||||
)
|
||||
image_index += 1
|
||||
else:
|
||||
# Get the slice of patches corresponding to the current video
|
||||
# Here we need to account for both the multiple video frames and the potential multiple patches per frame
|
||||
# As of now, InternVL only supports one patch per frame, but we keep the code flexible for future updates
|
||||
current_patch_index = video_patch_indices[video_index - 1] if video_index > 0 else 0
|
||||
end_patch_index = video_patch_indices[video_index]
|
||||
start_index = video_num_patches_indices[current_patch_index] if video_index > 0 else 0
|
||||
end_index = video_num_patches_indices[end_patch_index - 1]
|
||||
image_video_patches.append(video_pixel_values[start_index:end_index])
|
||||
# Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
|
||||
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
||||
video_prompt = "\n".join(
|
||||
f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
|
||||
for i in range(len(num_patches))
|
||||
)
|
||||
new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1)
|
||||
video_index += 1
|
||||
processed_text.append(new_prompt)
|
||||
|
||||
return processed_text, image_video_patches, image_index, video_index
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[ImageInput] = None,
|
||||
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
audio=None,
|
||||
videos: Optional[VideoInput] = None,
|
||||
**kwargs: Unpack[InternVLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text`
|
||||
is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and
|
||||
`crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwrags` arguments to
|
||||
GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
if text is None:
|
||||
raise ValueError("You have to specify text.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
InternVLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not isinstance(text, (list, tuple)):
|
||||
text = [text]
|
||||
|
||||
# Process images and videos separately, as videos don't support crop_to_patches
|
||||
image_num_patches = []
|
||||
video_num_patches = []
|
||||
image_videos_inputs = {}
|
||||
image_pixel_values = None
|
||||
video_pixel_values = None
|
||||
image_num_patches_indices = np.array([0])
|
||||
video_patch_indices = np.array([0])
|
||||
video_num_patches_indices = np.array([0])
|
||||
if images is not None:
|
||||
images = make_flat_list_of_images(images)
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
image_num_patches = image_inputs.pop("num_patches")
|
||||
image_pixel_values = image_inputs.pop("pixel_values")
|
||||
image_num_patches_indices = np.cumsum(image_num_patches)
|
||||
if videos is not None:
|
||||
videos = make_batched_videos(videos)
|
||||
num_frames_per_video = [len(video) for video in videos]
|
||||
video_patch_indices = np.cumsum(num_frames_per_video)
|
||||
output_kwargs["images_kwargs"]["crop_to_patches"] = False
|
||||
video_inputs = self.image_processor(images=videos, **output_kwargs["videos_kwargs"])
|
||||
video_num_patches = video_inputs.pop("num_patches")
|
||||
video_pixel_values = video_inputs.pop("pixel_values")
|
||||
video_num_patches_indices = np.cumsum(video_num_patches)
|
||||
|
||||
if images is not None or videos is not None:
|
||||
text, image_video_patches, image_index, video_index = self._insert_media_placeholders(
|
||||
text,
|
||||
image_pixel_values,
|
||||
video_pixel_values,
|
||||
image_num_patches,
|
||||
video_num_patches,
|
||||
image_num_patches_indices,
|
||||
video_num_patches_indices,
|
||||
video_patch_indices,
|
||||
)
|
||||
if images is not None and image_index != len(images):
|
||||
raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
|
||||
if videos is not None and video_index != len(videos):
|
||||
raise ValueError("Number of video placeholders in the prompt does not match the number of videos.")
|
||||
|
||||
# Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
|
||||
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
|
||||
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_videos_inputs})
|
||||
|
||||
def sample_indices_fn(
|
||||
self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
|
||||
):
|
||||
"""
|
||||
The function to generate indices of frames to sample from a video.
|
||||
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
`VideoMetadata` object containing metadat about the video, such as "total_num_frames" or "fps".
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If None, all frames are sampled.
|
||||
initial_shift (`bool`, `float` or `int`, defaults to `0`):
|
||||
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Array of frame indices to sample.
|
||||
"""
|
||||
if initial_shift is True:
|
||||
initial_shift = metadata.total_num_frames / num_frames / 2
|
||||
if num_frames is not None:
|
||||
indices = np.arange(
|
||||
initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames
|
||||
).astype(int)
|
||||
else:
|
||||
indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
|
||||
|
||||
return indices
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(tokenizer_input_names) + list(image_processor_input_names)
|
||||
|
||||
# Add model-specific video sampling method when applying the template
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||
chat_template: Optional[str] = None,
|
||||
num_frames: int = 8,
|
||||
initial_shift: Union[bool, float, int] = True,
|
||||
video_load_backend="pyav",
|
||||
**kwargs: Unpack[AllKwargsForChatTemplate],
|
||||
):
|
||||
"""
|
||||
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
|
||||
conversations to turn them into a single tokenizable string.
|
||||
|
||||
The input is expected to be in the following format, where each message content is a list consisting of text and
|
||||
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
|
||||
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "Please describe this image in detail."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
Args:
|
||||
conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
|
||||
The conversation to format.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
||||
chat template is used.
|
||||
num_frames (`int`, *optional*, defaults to 8):
|
||||
Number of frames to sample from a video when using the default `sample_indices_fn`.
|
||||
initial_shift (`bool`, `float` or `int`, defaults to `0`):
|
||||
The initial shift to apply when sampling frames using the default `sample_indices_fn`.
|
||||
If `True`, the shift is set so that frames are sampled from the middle of the video.
|
||||
"""
|
||||
sample_indices_fn = kwargs.pop(
|
||||
"sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
|
||||
)
|
||||
|
||||
return super().apply_chat_template(
|
||||
conversation,
|
||||
chat_template,
|
||||
video_load_backend=video_load_backend,
|
||||
num_frames=num_frames,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["InternVLProcessor"]
|
@ -130,6 +130,7 @@ VLM_CLASS_NAMES = [
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"chameleon",
|
||||
"internvl",
|
||||
"qwen2_5_omni",
|
||||
]
|
||||
|
||||
|
0
tests/models/internvl/__init__.py
Normal file
0
tests/models/internvl/__init__.py
Normal file
894
tests/models/internvl/test_modeling_internvl.py
Normal file
894
tests/models/internvl/test_modeling_internvl.py
Normal file
@ -0,0 +1,894 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch InternVL model."""
|
||||
|
||||
import unittest
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
BitsAndBytesConfig,
|
||||
InternVLConfig,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_av,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
InternVLForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class InternVLVisionText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
seq_length=7,
|
||||
image_seq_length=64,
|
||||
vision_feature_layer=-1,
|
||||
ignore_index=-100,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_id=1,
|
||||
num_channels=3,
|
||||
image_size=64,
|
||||
model_type="internvl",
|
||||
is_training=True,
|
||||
text_config={
|
||||
"model_type": "qwen2",
|
||||
"vocab_size": 99,
|
||||
"hidden_size": 128,
|
||||
"intermediate_size": 37,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"output_channels": 64,
|
||||
"hidden_act": "silu",
|
||||
"max_position_embeddings": 512,
|
||||
"rope_theta": 10000,
|
||||
"mlp_ratio": 4,
|
||||
"tie_word_embeddings": True,
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 0,
|
||||
"pad_token_id": 0,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 128,
|
||||
"image_size": 64,
|
||||
"patch_size": 4,
|
||||
"num_channels": 3,
|
||||
"hidden_act": "quick_gelu",
|
||||
"use_absolute_position_embeddings": True,
|
||||
},
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.image_token_id = image_token_id
|
||||
self.model_type = model_type
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.batch_size = batch_size
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.is_training = is_training
|
||||
self.image_seq_length = image_seq_length
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.seq_length = seq_length + image_seq_length
|
||||
|
||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||
self.vocab_size = text_config["vocab_size"]
|
||||
self.hidden_size = text_config["hidden_size"]
|
||||
self.num_attention_heads = text_config["num_attention_heads"]
|
||||
|
||||
def get_config(self):
|
||||
return InternVLConfig(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
model_type=self.model_type,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
image_token_id=self.image_token_id,
|
||||
image_seq_length=self.image_seq_length,
|
||||
vision_feature_layer=self.vision_feature_layer,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
# input_ids[:, -1] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, : self.image_seq_length] = self.image_token_id
|
||||
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||
model = InternVLForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||
config.torch_dtype = torch.float16
|
||||
model = InternVLForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"image-text-to-text": InternVLForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = InternVLVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=InternVLConfig, has_text_modality=False)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Qwen2 flash attention does not support right padding")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.small_model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||
self.medium_model_checkpoint = "OpenGVLab/InternVL3-2B-hf"
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_qwen2_small_model_integration_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_qwen2_small_model_integration_forward(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
# Forward
|
||||
with torch.inference_mode():
|
||||
output = model(**inputs)
|
||||
|
||||
actual_logits = output.logits[0, -1, :5].cpu()
|
||||
expected_logits = torch.tensor([11.9375, 14.8750, 14.0625, 10.7500, 6.9062], dtype=torch.bfloat16)
|
||||
self.assertTrue(
|
||||
torch.allclose(actual_logits, expected_logits, atol=0.1),
|
||||
f"Actual logits: {actual_logits}"
|
||||
f"\nExpected logits: {expected_logits}"
|
||||
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
|
||||
)
|
||||
|
||||
def test_qwen2_small_model_integration_generate_text_only(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
prompt = "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_qwen2_small_model_integration_generate_chat_template(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||
{"type": "text", "text": "Please describe the image explicitly."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_qwen2_small_model_integration_batched_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
|
||||
torch_device, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
def test_qwen2_small_model_integration_batched_generate_multi_image(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(
|
||||
BytesIO(
|
||||
requests.get(
|
||||
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||
).content
|
||||
)
|
||||
)
|
||||
image3 = Image.open(
|
||||
BytesIO(
|
||||
requests.get(
|
||||
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||
).content
|
||||
)
|
||||
)
|
||||
|
||||
inputs = processor(text=prompt, images=[[image1], [image2, image3]], padding=True, return_tensors="pt").to(
|
||||
torch_device, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = 'user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Angle' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@require_av
|
||||
@require_bitsandbytes
|
||||
def test_qwen2_medium_model_integration_video(self):
|
||||
processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.medium_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
# Prepare inputs
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
expected_output = 'The man is performing a forehand shot.' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@require_av
|
||||
def test_qwen2_small_model_integration_interleaved_images_videos(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, torch_dtype=torch.bfloat16, device_map=torch_device
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
},
|
||||
{"type": "text", "text": "What are the differences between these two images?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://llava-vl.github.io/static/images/view.jpg",
|
||||
},
|
||||
{"type": "text", "text": "Write a haiku for this image"},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||
expected_output = 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image**: This shows the Statue of Liberty on Liberty Island, with the' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check third output
|
||||
decoded_output = processor.decode(output[2], skip_special_tokens=True)
|
||||
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.small_model_checkpoint = "OpenGVLab/InternVL2_5-2B-MPO-hf"
|
||||
self.medium_model_checkpoint = "OpenGVLab/InternVL2_5-8B-MPO-hf"
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_llama_small_model_integration_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "The image shows two cats sleeping on a pink couch. They are lying side by side, with their"
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_llama_small_model_integration_forward(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
# Forward
|
||||
with torch.inference_mode():
|
||||
output = model(**inputs)
|
||||
|
||||
actual_logits = output.logits[0, -1, :5].cpu()
|
||||
expected_logits = torch.tensor([-9.8750, -0.4258, 1.4844, -10.3125, -10.3125], dtype=torch.bfloat16)
|
||||
# The original implementation and the transformers implementation do not match exactly, hence the higher tolerance.
|
||||
# The difference is likely due to the different implementations of the attention mechanism (different order of operations)
|
||||
# between the transformers Llama model and the original InternLM model.
|
||||
# The difference has almost no effect on the output tokens, but it does affect the logits a lot more.
|
||||
self.assertTrue(
|
||||
torch.allclose(actual_logits, expected_logits, atol=1),
|
||||
f"Actual logits: {actual_logits}"
|
||||
f"\nExpected logits: {expected_logits}"
|
||||
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
|
||||
)
|
||||
|
||||
def test_llama_small_model_integration_generate_text_only(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
prompt = "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "Autumn leaves fall,\nNature's breath, a season's sigh,\nSilent woods awake."
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_llama_small_model_integration_generate_chat_template(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||
{"type": "text", "text": "Please describe the image explicitly."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
with torch.no_grad():
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
decoded_output = processor.decode(
|
||||
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
expected_output = "The image shows two cats sleeping on a pink couch. They are lying side by side, with their"
|
||||
self.assertEqual(decoded_output, expected_output)
|
||||
|
||||
def test_llama_small_model_integration_batched_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
|
||||
torch_device, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese gate in the background, adorned with red and gold colors and Chinese characters' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
def test_llama_small_model_integration_batched_generate_multi_image(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(
|
||||
BytesIO(
|
||||
requests.get(
|
||||
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||
).content
|
||||
)
|
||||
)
|
||||
image3 = Image.open(
|
||||
BytesIO(
|
||||
requests.get(
|
||||
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||
).content
|
||||
)
|
||||
)
|
||||
|
||||
inputs = processor(text=prompt, images=[[image1], [image2, image3]], padding=True, return_tensors="pt").to(
|
||||
torch_device, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, still waters.' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = 'user\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After closely examining the images again, I can see that there are several differences' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@require_av
|
||||
@require_bitsandbytes
|
||||
def test_llama_medium_model_integration_video(self):
|
||||
processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.medium_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
# Prepare inputs
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
expected_output = "The man is performing a forehand shot."
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@require_av
|
||||
def test_llama_small_model_integration_interleaved_images_videos(self):
|
||||
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||
model = InternVLForConditionalGeneration.from_pretrained(
|
||||
self.small_model_checkpoint, torch_dtype=torch.bfloat16, device_map=torch_device
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
},
|
||||
{"type": "text", "text": "What are the difference between these two images?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://llava-vl.github.io/static/images/view.jpg",
|
||||
},
|
||||
{"type": "text", "text": "Write a haiku for this image"},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, dtype=torch.bfloat16)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||
expected_output = 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot. This is a common shot in tennis where the player swings the racket across their' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check third output
|
||||
decoded_output = processor.decode(output[2], skip_special_tokens=True)
|
||||
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, untouched dreams.' # fmt: skip
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
327
tests/models/internvl/test_processor_internvl.py
Normal file
327
tests/models/internvl/test_processor_internvl.py
Normal file
@ -0,0 +1,327 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import GotOcr2ImageProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = InternVLProcessor
|
||||
videos_input_name = "pixel_values"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
image_processor = GotOcr2ImageProcessor(
|
||||
do_resize=True,
|
||||
size={"height": 20, "width": 20},
|
||||
max_patches=2,
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
do_center_crop=True,
|
||||
image_mean=[0.485, 0.456, 0.406],
|
||||
image_std=[0.229, 0.224, 0.225],
|
||||
do_convert_rgb=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B-hf", padding_side="left")
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
processor = InternVLProcessor.from_pretrained(
|
||||
"OpenGVLab/InternVL3-1B-hf",
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
**processor_kwargs,
|
||||
)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
cls.image_token = processor.fake_image_token
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
return {"image_seq_length": 10}
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
||||
|
||||
@require_av
|
||||
@require_torch
|
||||
def test_process_interleaved_images_videos(self):
|
||||
processor = self.get_processor()
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
},
|
||||
{"type": "text", "text": "What are the differences between these two images?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://llava-vl.github.io/static/images/view.jpg",
|
||||
},
|
||||
{"type": "text", "text": "Write a haiku for this image"},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
inputs_batched = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
|
||||
images_patches_index = 0
|
||||
for i, message in enumerate(messages):
|
||||
inputs = processor.apply_chat_template(
|
||||
message,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
|
||||
torch.testing.assert_close(
|
||||
inputs["input_ids"][0], inputs_batched["input_ids"][i][-inputs["input_ids"].shape[1] :]
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
inputs["pixel_values"],
|
||||
inputs_batched["pixel_values"][
|
||||
images_patches_index : images_patches_index + inputs["pixel_values"].shape[0]
|
||||
],
|
||||
)
|
||||
images_patches_index += inputs["pixel_values"].shape[0]
|
||||
|
||||
# Override video chat_template tests as InternVLProcessor returns flattened video features
|
||||
@require_av
|
||||
def test_apply_chat_template_video_special_processing(self):
|
||||
"""
|
||||
Tests that models can use their own preprocessing to preprocess conversations.
|
||||
"""
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": video_file_path},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
def _process_messages_for_chat_template(
|
||||
conversation,
|
||||
batch_images,
|
||||
batch_videos,
|
||||
batch_video_metadata,
|
||||
**chat_template_kwargs,
|
||||
):
|
||||
# Let us just always return a dummy prompt
|
||||
new_msg = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"}, # no need to use path, video is loaded already by this moment
|
||||
{"type": "text", "text": "Dummy prompt for preprocess testing"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
return new_msg
|
||||
|
||||
processor._process_messages_for_chat_template = _process_messages_for_chat_template
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
|
||||
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
|
||||
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
|
||||
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||
|
||||
def test_apply_chat_template_video_frame_sampling(self):
|
||||
processor = self.get_processor()
|
||||
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
signature = inspect.signature(processor.__call__)
|
||||
if "videos" not in {*signature.parameters.keys()} or (
|
||||
signature.parameters.get("videos") is not None
|
||||
and signature.parameters["videos"].annotation == inspect._empty
|
||||
):
|
||||
self.skipTest("Processor doesn't accept videos at input")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
num_frames = 3
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
|
||||
|
||||
# Load with `video_fps` arg
|
||||
video_fps = 1
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=None, # force to use default num_frames
|
||||
return_tensors="np",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
|
||||
|
||||
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
|
||||
# Load without any arg should use the default loading method
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||
|
||||
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||
# because we assume they come from one video
|
||||
messages[0][0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": [
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||
],
|
||||
}
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 2)
|
@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
|
||||
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
"Emu3VQVAE", # Building part of bigger (tested) model
|
||||
"Emu3TextModel", # Building part of bigger (tested) model
|
||||
"InternVLVisionModel", # Building part of bigger (tested) model
|
||||
"JanusVisionModel", # Building part of bigger (tested) model
|
||||
"TimesFmModel", # Building part of bigger (tested) model
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user