mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
Add InternVL (2.5 MPO) (#35968)
* initial commit
* add convert internvl
* add first end-to-end working internvl
* nit prompt and image proc
* add working chat template
* add conversion for llama-based models
* add tests
* pass all tests
* fix isort
* fix modular after main merge
* add video processing for internvl
* add support for interleaved images and videos
* Remove processing and config from modular, add more tests
* add llama model tests
* Modify processor for compatibility with refactored got ocr image processor
* add comments in processor
* Add docs and nits
* change video processing to use custom sample_indices_fn
* rebase and fix tests
* add processor tests
* Apply changes from Raushan's review
* Use the new attention interface for the vision model
* nits
* add support for custom video_load_backend
* remove mention of InternVLTokenizer
* refactor vision model to simplify logic
* refactor processor for better readability
* fix copies
* fix require av processor test
* refactor internVL vision
* Update processor and fix processing tests
* fix docstring
* update convert_weights for internvl3
* change image processor to fast by default
* remove do_center_crop=True in convert_weights
* force use_cache to True
* push_to_hub before reloading
* fix internVLVision for larger models
* update convert weight for qk norm
* fix convert_weights
* fix eos_token_id in convert
* update docs and integration tests
* make modifications after review
* fix wrong k_norm and reduce modular
* change image_token_index to image_token_id
* change checkpoint to OpenGVLab org
* last nits
* explicitly del self.num_key_value_groups
* add extra special tokens
This commit is contained in:
parent
b0c6ff5e13
commit
a245011252
@ -953,6 +953,8 @@
   title: InstructBLIP
 - local: model_doc/instructblipvideo
   title: InstructBlipVideo
+- local: model_doc/internvl
+  title: InternVL
 - local: model_doc/janus
   title: Janus
 - local: model_doc/kosmos-2
349  docs/source/en/model_doc/internvl.md  Normal file
@ -0,0 +1,349 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
    </div>
</div>

# InternVL

The InternVL3 family of Visual Language Models was introduced in [InternVL3: Exploring Advanced Training and Test-Time Recipes for Open-Source Multimodal Models](https://huggingface.co/papers/2504.10479).

The abstract from the paper is the following:

*We introduce InternVL3, a significant advancement in the InternVL series featuring a native multimodal pre-training paradigm. Rather than adapting a text-only large language model (LLM) into a multimodal large language model (MLLM) that supports visual inputs, InternVL3 jointly acquires multimodal and linguistic capabilities from both diverse multimodal data and pure-text corpora during a single pre-training stage. This unified training paradigm effectively addresses the complexities and alignment challenges commonly encountered in conventional post-hoc training pipelines for MLLMs. To further improve performance and scalability, InternVL3 incorporates variable visual position encoding (V2PE) to support extended multimodal contexts, employs advanced post-training techniques such as supervised fine-tuning (SFT) and mixed preference optimization (MPO), and adopts test-time scaling strategies alongside an optimized training infrastructure. Extensive empirical evaluations demonstrate that InternVL3 delivers superior performance across a wide range of multi-modal tasks. In particular, InternVL3-78B achieves a score of 72.2 on the MMMU benchmark, setting a new state-of-the-art among open-source MLLMs. Its capabilities remain highly competitive with leading proprietary models, including ChatGPT-4o, Claude 3.5 Sonnet, and Gemini 2.5 Pro, while also maintaining strong pure-language proficiency. In pursuit of open-science principles, we will publicly release both the training data and model weights to foster further research and development in next-generation MLLMs.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/internvl_architecture.png" alt="drawing" width="600"/>

<small> Overview of the InternVL3 model architecture, which is the same as InternVL2.5. Taken from the <a href="https://huggingface.co/OpenGVLab/InternVL3-1B">original checkpoint</a>. </small>

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/internvl_overview_performance.png" alt="drawing" width="600"/>

<small> Comparison of InternVL3 performance on OpenCompass against other SOTA VLLMs. Taken from the <a href="https://huggingface.co/OpenGVLab/InternVL3-1B">original checkpoint</a>. </small>

This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/OpenGVLab/InternVL).
## Usage example

### Inference with Pipeline

Here is how you can use the `image-text-to-text` pipeline to perform inference with the `InternVL3` models in just a few lines of code:

```python
>>> from transformers import pipeline

>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {
...                 "type": "image",
...                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
...             },
...             {"type": "text", "text": "Describe this image."},
...         ],
...     },
... ]

>>> pipe = pipeline("image-text-to-text", model="OpenGVLab/InternVL3-1B-hf")
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
>>> outputs[0]["generated_text"]
'The image showcases a vibrant scene of nature, featuring several flowers and a bee. \n\n1. **Foreground Flowers**: \n - The primary focus is on a large, pink cosmos flower with a prominent yellow center. The petals are soft and slightly r'
```

### Inference on a single image

This example demonstrates how to perform inference on a single image with the InternVL models using chat templates.

> [!NOTE]
> Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
...             {"type": "text", "text": "Please describe the image explicitly."},
...         ],
...     }
... ]

>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

>>> generate_ids = model.generate(**inputs, max_new_tokens=50)
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)

>>> decoded_output
'The image shows two cats lying on a pink blanket. The cat on the left is a tabby with a mix of brown, black, and white fur, and it appears to be sleeping with its head resting on the blanket. The cat on the'
```

### Text-only generation

This example shows how to generate text using the InternVL model without providing any image input.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "text", "text": "Write a haiku"},
...         ],
...     }
... ]

>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)

>>> generate_ids = model.generate(**inputs, max_new_tokens=50)
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)

>>> print(decoded_output)
"Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."
```
### Batched image and text inputs

InternVL models also support batched image and text inputs.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
...                 {"type": "text", "text": "Describe this image"},
...             ],
...         },
...     ],
... ]

>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

>>> output = model.generate(**inputs, max_new_tokens=25)

>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace.",
'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of']
```

### Batched multi-image input

This implementation of the InternVL models supports batched text-image inputs with a different number of images per prompt.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
...             ],
...         },
...     ],
... ]

>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

>>> output = model.generate(**inputs, max_new_tokens=25)

>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace.",
'user\n\n\nThese images depict two different landmarks. Can you identify them?\nassistant\nYes, these images depict the Statue of Liberty and the Golden Gate Bridge.']
```
### Video input

InternVL models can also handle video inputs. Here is an example of how to perform inference on a video input using chat templates.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
>>> import torch

>>> model_checkpoint = "OpenGVLab/InternVL3-8B-hf"
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, quantization_config=quantization_config)

>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {
...                 "type": "video",
...                 "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
...             },
...             {"type": "text", "text": "What type of shot is the man performing?"},
...         ],
...     }
... ]
>>> inputs = processor.apply_chat_template(
...     messages,
...     return_tensors="pt",
...     add_generation_prompt=True,
...     tokenize=True,
...     return_dict=True,
... ).to(model.device, dtype=torch.float16)

>>> output = model.generate(**inputs, max_new_tokens=25)

>>> decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
>>> decoded_output
'The man is performing a forehand shot.'
```
### Interleaved image and video inputs

This example showcases how to handle a batch of chat conversations with interleaved image and video inputs using chat templates.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "video", "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"},
...                 {"type": "text", "text": "What type of shot is the man performing?"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
... ]
>>> inputs = processor.apply_chat_template(
...     messages,
...     padding=True,
...     add_generation_prompt=True,
...     tokenize=True,
...     return_dict=True,
...     return_tensors="pt",
... ).to(model.device, dtype=torch.bfloat16)

>>> outputs = model.generate(**inputs, max_new_tokens=25)

>>> decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)
>>> decoded_outputs
['user\n\n\nThese images depict two different landmarks. Can you identify them?\nassistant\nThe images depict the Statue of Liberty and the Golden Gate Bridge.',
'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot',
"user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace."]
```
## InternVLVisionConfig

[[autodoc]] InternVLVisionConfig

## InternVLConfig

[[autodoc]] InternVLConfig

## InternVLVisionModel

[[autodoc]] InternVLVisionModel
    - forward

## InternVLForConditionalGeneration

[[autodoc]] InternVLForConditionalGeneration
    - forward

## InternVLProcessor

[[autodoc]] InternVLProcessor
@ -18,7 +18,7 @@ from collections.abc import Iterable
 from contextlib import redirect_stdout
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import Callable, Optional, Union
 from urllib.parse import urlparse

 import numpy as np
@ -77,8 +77,7 @@ if is_vision_available():
     pil_torch_interpolation_mapping = {}


-if TYPE_CHECKING:
-    if is_torch_available():
-        import torch
+if is_torch_available():
+    import torch

@ -162,6 +161,15 @@ def is_valid_list_of_images(images: list):
     return images and all(is_valid_image(image) for image in images)


+def concatenate_list(input_list):
+    if isinstance(input_list[0], list):
+        return [item for sublist in input_list for item in sublist]
+    elif isinstance(input_list[0], np.ndarray):
+        return np.concatenate(input_list, axis=0)
+    elif isinstance(input_list[0], torch.Tensor):
+        return torch.cat(input_list, dim=0)
+
+
 def valid_images(imgs):
     # If we have an list of images, make sure every image is valid
     if isinstance(imgs, (list, tuple)):
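As an aside (not part of the diff): a minimal sketch of how the new `concatenate_list` helper behaves for the three input types it handles, assuming it is importable from `transformers.image_utils` once this change is installed; the example values are made up.

```python
import numpy as np
import torch

from transformers.image_utils import concatenate_list

# Nested Python lists are flattened into one list.
assert concatenate_list([[1, 2], [3]]) == [1, 2, 3]

# Lists of NumPy arrays are concatenated along the first axis.
assert concatenate_list([np.zeros((2, 3)), np.zeros((4, 3))]).shape == (6, 3)

# Lists of torch tensors are concatenated along dim 0.
assert concatenate_list([torch.zeros(2, 3), torch.zeros(4, 3)]).shape == (6, 3)
```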
@ -143,6 +143,7 @@ if TYPE_CHECKING:
     from .informer import *
     from .instructblip import *
     from .instructblipvideo import *
+    from .internvl import *
     from .jamba import *
     from .janus import *
     from .jetmoe import *
@ -162,6 +162,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("informer", "InformerConfig"),
         ("instructblip", "InstructBlipConfig"),
         ("instructblipvideo", "InstructBlipVideoConfig"),
+        ("internvl", "InternVLConfig"),
+        ("internvl_vision", "InternVLVisionConfig"),
         ("jamba", "JambaConfig"),
         ("janus", "JanusConfig"),
         ("jetmoe", "JetMoeConfig"),
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("informer", "Informer"),
         ("instructblip", "InstructBLIP"),
         ("instructblipvideo", "InstructBlipVideo"),
+        ("internvl", "InternVL"),
+        ("internvl_vision", "InternVLVision"),
         ("jamba", "Jamba"),
         ("janus", "Janus"),
         ("jetmoe", "JetMoe"),
@ -797,6 +801,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("internvl_vision", "internvl"),
         ("qwen2_5_vl_text", "qwen2_5_vl"),
         ("qwen2_vl_text", "qwen2_vl"),
         ("sam_vision_model", "sam"),
@ -151,6 +151,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("ijepa", "IJepaModel"),
         ("imagegpt", "ImageGPTModel"),
         ("informer", "InformerModel"),
+        ("internvl_vision", "InternVLVisionModel"),
         ("jamba", "JambaModel"),
         ("janus", "JanusModel"),
         ("jetmoe", "JetMoeModel"),
@ -862,6 +863,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
         ("idefics2", "Idefics2ForConditionalGeneration"),
         ("idefics3", "Idefics3ForConditionalGeneration"),
         ("instructblip", "InstructBlipForConditionalGeneration"),
+        ("internvl", "InternVLForConditionalGeneration"),
         ("janus", "JanusForConditionalGeneration"),
         ("kosmos-2", "Kosmos2ForConditionalGeneration"),
         ("llama4", "Llama4ForConditionalGeneration"),
@ -75,6 +75,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("idefics3", "Idefics3Processor"),
         ("instructblip", "InstructBlipProcessor"),
         ("instructblipvideo", "InstructBlipVideoProcessor"),
+        ("internvl", "InternVLProcessor"),
         ("janus", "JanusProcessor"),
         ("kosmos-2", "Kosmos2Processor"),
         ("layoutlmv2", "LayoutLMv2Processor"),
@ -258,6 +258,7 @@ else:
             ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
             ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
             ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+            ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
             (
                 "jamba",
                 (
28  src/transformers/models/internvl/__init__.py  Normal file
@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_internvl import *
    from .modeling_internvl import *
    from .processing_internvl import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
225  src/transformers/models/internvl/configuration_internvl.py  Normal file
@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class InternVLVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVLVisionModel`]. It is used to instantiate an InternVLVisionModel
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
    a similar configuration to that of InternVL3-1B,
    e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf).

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries, keys and values.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply normalization to the queries and keys before the attention operation.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        projection_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the projection layer.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The type of normalization to use in the encoder. Can be `"layer_norm"` or `"rms_norm"`.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int` or `list[int]`, *optional*, defaults to `[448, 448]`):
            The size (resolution) of each image.
        patch_size (`int` or `list[int]`, *optional*, defaults to `[14, 14]`):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to use BERT-style absolute position embeddings.
        layer_scale_init_value (`float`, *optional*, defaults to 0.1):
            Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
        use_mean_pooling (`bool`, *optional*, defaults to `True`):
            Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
            CLS token, before applying the classification head.

    Example:

    ```python
    >>> from transformers import InternVLVisionConfig, InternVLVisionModel

    >>> # Initializing an InternVLVisionModel OpenGVLab/InternVL3-1B-hf style configuration
    >>> configuration = InternVLVisionConfig()

    >>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
    >>> model = InternVLVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "internvl_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        attention_bias=False,
        use_qk_norm=False,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_dropout=0.0,
        projection_dropout=0.0,
        initializer_range=0.02,
        norm_type="layer_norm",
        layer_norm_eps=1e-06,
        image_size=[448, 448],
        patch_size=[14, 14],
        num_channels=3,
        use_mask_token=False,
        use_absolute_position_embeddings=True,
        layer_scale_init_value=0.1,
        use_mean_pooling=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_bias = attention_bias
        self.use_qk_norm = use_qk_norm
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout = attention_dropout
        self.projection_dropout = projection_dropout
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.layer_norm_eps = layer_norm_eps

        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size

        self.num_channels = num_channels
        self.use_mask_token = use_mask_token
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.layer_scale_init_value = layer_scale_init_value
        self.use_mean_pooling = use_mean_pooling


class InternVLConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVLForConditionalGeneration`]. It is used to instantiate an
    InternVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of InternVL3-1B,
    e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `InternVLVisionConfig`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
            The config object or dictionary of the text backbone.
        image_token_id (`int`, *optional*, defaults to 151667):
            The image token index to encode the image prompt.
        image_seq_length (`int`, *optional*, defaults to 256):
            Number of image tokens to use per image patch.
        downsample_ratio (`float`, *optional*, defaults to 0.5):
            Factor by which to downsample the image.
        projector_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the projector.
        vision_feature_layer (`int`, *optional*, defaults to -1):
            The index of the layer to use as the image features.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`.

    ```python
    >>> from transformers import InternVLForConditionalGeneration, InternVLConfig

    >>> # Initializing an InternVL style configuration
    >>> configuration = InternVLConfig()

    >>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
    >>> model = InternVLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "internvl"
    sub_configs = {"text_config": AutoConfig, "vision_config": InternVLVisionConfig}

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        image_token_id=151667,
        image_seq_length=256,
        downsample_ratio=0.5,
        projector_hidden_act="gelu",
        vision_feature_layer=-1,
        vision_feature_select_strategy="default",
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length
        self.downsample_ratio = downsample_ratio
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_layer = vision_feature_layer
        self.vision_feature_select_strategy = vision_feature_select_strategy

        if isinstance(vision_config, dict):
            self.vision_config = InternVLVisionConfig(**vision_config)
        elif isinstance(vision_config, InternVLVisionConfig):
            self.vision_config = vision_config
        elif vision_config is None:
            self.vision_config = InternVLVisionConfig()

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["qwen2"]()

        self.text_config = text_config

        super().__init__(**kwargs)


__all__ = ["InternVLVisionConfig", "InternVLConfig"]
@ -0,0 +1,417 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re

import torch
from einops import rearrange

from transformers import (
    AutoModel,
    AutoTokenizer,
    GenerationConfig,
    GotOcr2ImageProcessorFast,
    InternVLConfig,
    InternVLForConditionalGeneration,
    InternVLProcessor,
    InternVLVisionConfig,
    LlamaConfig,
    Qwen2Config,
)


LM_TYPE_CORRESPONDENCE = {
    "OpenGVLab/InternVL2_5-1B-MPO": "qwen2",
    "OpenGVLab/InternVL2_5-2B-MPO": "llama",
    "OpenGVLab/InternVL2_5-4B-MPO": "qwen2",
    "OpenGVLab/InternVL2_5-8B-MPO": "llama",
    "OpenGVLab/InternVL2_5-26B-MPO": "llama",
    "OpenGVLab/InternVL2_5-38B-MPO": "qwen2",
    "OpenGVLab/InternVL2_5-78B-MPO": "qwen2",
    "OpenGVLab/InternVL3-1B": "qwen2",
    "OpenGVLab/InternVL3-2B": "qwen2",
    "OpenGVLab/InternVL3-8B": "qwen2",
    "OpenGVLab/InternVL3-9B": "llama",
    "OpenGVLab/InternVL3-14B": "qwen2",
    "OpenGVLab/InternVL3-38B": "qwen2",
    "OpenGVLab/InternVL3-78B": "qwen2",
}

UNNECESSARY_CONFIG_KEYS = [ "_name_or_path", "_attn_implementation_autoset", "auto_map", "use_bfloat16", "use_flash_attn", "bias", "laux_allreduce", "moe_coeff_ratio", "moe_intermediate_size", "moe_output_scale", "noisy_gate_policy", "shared_expert_intermediate_size", "use_residual", "use_moe", "use_rts", "use_weighted_residual", "moe_config", "num_experts", "num_routed_experts", "num_shared_experts", "capacity_factor", "eval_capacity_factor", "drop_path_rate"]  # fmt: skip

# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION = {
    # Vision encoder mapping
    r"vision_model": r"vision_tower",
    r"layers": r"layer",
    r"class_embedding": r"cls_token",
    r"position_embedding": r"position_embeddings",
    r"patch_embedding": r"patch_embeddings.projection",
    r"ls(\d+)": r"lambda_\1",
    r"attn.proj": r"attention.projection_layer",
    r"attn.dropout": r"attention.projection_dropout",
    r"attn": r"attention",
    r"norm1": r"layernorm_before",
    r"norm2": r"layernorm_after",

}

ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA = {
    # Text decoder mapping (LLaMA-based checkpoints)
    r"tok_embeddings": r"embed_tokens",
    r"attention.wo": r"self_attn.o_proj",
    r"feed_forward.w1": r"mlp.gate_proj",
    r"feed_forward.w2": r"mlp.down_proj",
    r"feed_forward.w3": r"mlp.up_proj",
    r"attention_norm": r"input_layernorm",
    r"ffn_norm": r"post_attention_layernorm",
    r"output": r"lm_head",
}

ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI = {
    # Multimodal projector mapping
    r"mlp1.0": r"multi_modal_projector.layer_norm",
    r"mlp1.1": r"multi_modal_projector.linear_1",
    r"mlp1.3": r"multi_modal_projector.linear_2",
}


chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% if message['content'] is string %}"
    "{{ message['content'] }}"
    "{% else %}"
    "{% for content in message['content'] %}"
    "{% if content['type'] == 'image' %}"
    "{{ '<image>\n' }}"
    "{% elif content['type'] == 'video' %}"
    "{{ '<video>\n' }}"
    "{% elif content['type'] == 'text' %}"
    "{{ content['text'] }}"
    "{% endif %}"
    "{% endfor %}"
    "{% endif %}"
    "{{'<|im_end|>\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{'<|im_start|>assistant\n' }}"
    "{% endif %}"
)
# fmt: on

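Purely as an illustration (not part of the script): a minimal sketch of what the template above renders for a single user turn containing one image, assuming the `jinja2` package is available and the snippet runs after the `chat_template` definition. In practice, the processor applies this template via `processor.apply_chat_template`, as shown in the docs above.

```python
from jinja2 import Template

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]},
]
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|im_start|>user
# <image>
# Describe this image.<|im_end|>
# <|im_start|>assistant
```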
CONTEXT_LENGTH = 8192


def convert_old_keys_to_new_keys(state_dict_keys: dict = None, path: str = None):
    """
    This function should be applied only once, on the concatenated keys to efficiently rename using
    the key mappings.
    """
    output_dict = {}
    if state_dict_keys is not None:
        old_text_vision = "\n".join([key for key in state_dict_keys if key.startswith("vision_model")])
        new_text = old_text_vision
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION.items():
            new_text = re.sub(pattern, replacement, new_text)
        output_dict = dict(zip(old_text_vision.split("\n"), new_text.split("\n")))
        old_text_language = "\n".join([key for key in state_dict_keys if key.startswith("language_model")])
        new_text = old_text_language
        if LM_TYPE_CORRESPONDENCE[path] == "llama":
            for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA.items():
                new_text = re.sub(pattern, replacement, new_text)
        output_dict.update(dict(zip(old_text_language.split("\n"), new_text.split("\n"))))
        old_text_multi = "\n".join(
            [
                key
                for key in state_dict_keys
                if not (key.startswith("language_model") or key.startswith("vision_model"))
            ]
        )
        new_text = old_text_multi
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI.items():
            new_text = re.sub(pattern, replacement, new_text)
        output_dict.update(dict(zip(old_text_multi.split("\n"), new_text.split("\n"))))

    return output_dict


def load_original_state_dict(input_base_path):
    model = AutoModel.from_pretrained(
        input_base_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=False,
        trust_remote_code=True,
    ).eval()

    return model.state_dict()


def get_internvl_config(input_base_path):
    base_config = AutoModel.from_pretrained(input_base_path, trust_remote_code=True).config
    llm_config = base_config.llm_config.to_dict()
    vision_config = base_config.vision_config.to_dict()
    vision_config["use_absolute_position_embeddings"] = True
    if LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
        image_token_id = 151667
        language_config_class = Qwen2Config
    else:
        image_token_id = 92546
        language_config_class = LlamaConfig

    llm_config = {k: v for k, v in llm_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
    # Force use_cache to True
    llm_config["use_cache"] = True
    # Force correct eos_token_id for InternVL3
    if "InternVL3" in input_base_path and LM_TYPE_CORRESPONDENCE[input_base_path] == "qwen2":
        llm_config["eos_token_id"] = 151645

    vision_config = {k: v for k, v in vision_config.items() if k not in UNNECESSARY_CONFIG_KEYS}
    if "attention_probs_dropout_prob" in vision_config:
        attention_dropout = vision_config.pop("attention_probs_dropout_prob")
        vision_config["attention_dropout"] = attention_dropout
        vision_config["projection_dropout"] = attention_dropout
    if "qk_normalization" in vision_config:
        use_qk_norm = vision_config.pop("qk_normalization")
        vision_config["use_qk_norm"] = use_qk_norm
    if "qkv_bias" in vision_config:
        attention_bias = vision_config.pop("qkv_bias")
        vision_config["attention_bias"] = attention_bias

    return InternVLConfig(
        text_config=language_config_class(**llm_config),
        vision_config=InternVLVisionConfig(**vision_config),
        image_token_id=image_token_id,
    )


def write_model(
    model_path,
    input_base_path,
    push_to_hub=False,
    hub_dir=None,
):
    os.makedirs(model_path, exist_ok=True)

    config = get_internvl_config(input_base_path)
    config.architectures = ["InternVLForConditionalGeneration"]
    config.save_pretrained(model_path)
    if push_to_hub:
        config.push_to_hub(hub_dir, use_temp_dir=True)
    print("Model config saved successfully...")

    # ------------------------------------------------------------
    # Convert weights
    # ------------------------------------------------------------

    print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
    state_dict_old = load_original_state_dict(input_base_path)
    print("Converting model...")
    all_keys = list(state_dict_old.keys())
    new_keys = convert_old_keys_to_new_keys(all_keys, path=input_base_path)
    lm_dim = config.text_config.hidden_size
    dim = config.vision_config.hidden_size
    state_dict = {}
    for key in all_keys:
        new_key = new_keys[key]
        if "attn.qkv" in key:
            new_key_query = new_key.replace("attention.qkv", "attention.q_proj")
            state_dict[new_key_query] = state_dict_old[key][:dim]

            new_key_key = new_key.replace("attention.qkv", "attention.k_proj")
            state_dict[new_key_key] = state_dict_old[key][dim : 2 * dim]

            new_key_value = new_key.replace("attention.qkv", "attention.v_proj")
            state_dict[new_key_value] = state_dict_old[key][-dim:]
        elif "attention.wqkv" in key:
            num_key_value_groups = config.text_config.num_attention_heads // config.text_config.num_key_value_heads
            head_dim = config.text_config.head_dim
            wqkv_weights = state_dict_old[key]

            qkv_vecs = rearrange(wqkv_weights, "(h gs d) z -> h gs d z", gs=2 + num_key_value_groups, d=head_dim)
            q_proj = qkv_vecs[:, :num_key_value_groups, ...].reshape(-1, lm_dim).contiguous()
            k_proj = qkv_vecs[:, -2, ...].reshape(-1, lm_dim).contiguous()
            v_proj = qkv_vecs[:, -1, ...].reshape(-1, lm_dim).contiguous()

            new_key_query = new_key.replace("attention.wqkv", "self_attn.q_proj")
            state_dict[new_key_query] = q_proj

            new_key_key = new_key.replace("attention.wqkv", "self_attn.k_proj")
            state_dict[new_key_key] = k_proj

            new_key_value = new_key.replace("attention.wqkv", "self_attn.v_proj")
            state_dict[new_key_value] = v_proj
        else:
            state_dict[new_key] = state_dict_old[key]

    del state_dict_old
    gc.collect()

    print("Loading the checkpoint in a InternVLForConditionalGeneration model.")
    model = InternVLForConditionalGeneration(config)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    model = model.to(torch.bfloat16)
    print("model dtype:", model.dtype)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

    print("Saving the model.")
    model.save_pretrained(model_path)
    if push_to_hub:
        model.push_to_hub(hub_dir, use_temp_dir=True)

    image_processor = GotOcr2ImageProcessorFast.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    processor = InternVLProcessor(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
    processor.save_pretrained(model_path)
    if push_to_hub:
        processor.push_to_hub(hub_dir, use_temp_dir=True)

    # generation config
    if LM_TYPE_CORRESPONDENCE[input_base_path] == "llama":
        print("Saving generation config...")
        # in the original model, eos_token is not the same in the text_config and the generation_config
        # ("</s>" - 2 in the text_config and "<|im_end|>" - 92542 in the generation_config)
        generation_config = GenerationConfig(
            eos_token_id=92542,
        )
        generation_config.save_pretrained(model_path)
        if push_to_hub:
            generation_config.push_to_hub(hub_dir, use_temp_dir=True)

    # del state_dict, model

    # # Safety check: reload the converted model
    gc.collect()
    print("Reloading the model to check if it's saved correctly.")
    model = InternVLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    print("Model reloaded successfully.")
    del model


def write_tokenizer(save_dir: str, push_to_hub: bool = False, path: str = None, hub_dir: str = None):
    if LM_TYPE_CORRESPONDENCE[path] == "qwen2":
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            return_token_type_ids=False,
            extra_special_tokens={
                "start_image_token": "<img>",
                "end_image_token": "</img>",
                "context_image_token": "<IMG_CONTEXT>",
            },
        )
        tokenizer.model_max_length = CONTEXT_LENGTH
        tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<img>",
                    "</img>",
                    "<IMG_CONTEXT>",
                    "<quad>",
                    "</quad>",
                    "<ref>",
                    "</ref>",
                    "<box>",
                    "</box>",
                ]
            },
            replace_additional_special_tokens=False,
        )
    else:
        # Obtained with:
        # tokenizer_llama_fast = LlamaTokenizerFast.from_pretrained(
        #     "OpenGVLab/InternVL2_5-2B-MPO", pad_token="</s>", legacy=False, from_slow=True
        # )
        # tokenizer_llama_fast._tokenizer.pre_tokenizer.prepend_scheme = "never"
        # Then manually modifying `added_tokens_decoder` indices to match the original tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "./intern_vl_hf_implem/tokenizer_internvl_llama_fast",
            return_token_type_ids=False,
            extra_special_tokens={
                "start_image_token": "<img>",
                "end_image_token": "</img>",
                "context_image_token": "<IMG_CONTEXT>",
            },
        )

    tokenizer.chat_template = chat_template
    tokenizer.save_pretrained(save_dir)
    if push_to_hub:
        tokenizer.push_to_hub(hub_dir, use_temp_dir=True)


def write_image_processor(save_dir: str, push_to_hub: bool = False, hub_dir: str = None):
    image_processor = GotOcr2ImageProcessorFast(
        do_resize=True,
        size={"height": 448, "width": 448},
        do_rescale=True,
        rescale_factor=1 / 255,
        do_normalize=True,
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
        do_convert_rgb=True,
    )

    image_processor.save_pretrained(save_dir)
    if push_to_hub:
        image_processor.push_to_hub(hub_dir, use_temp_dir=True)


def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--input_dir",
|
||||||
|
default="OpenGVLab/InternVL3-1B",
|
||||||
|
help="Location of original InternVL model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default="InternVL3-1B-hf",
|
||||||
|
help="Location to write HF model and processors",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hub_dir",
|
||||||
|
default="OpenGVLab/InternVL3-1B-hf",
|
||||||
|
help="Location to write HF model and processors",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
write_tokenizer(
|
||||||
|
save_dir=args.output_dir,
|
||||||
|
push_to_hub=args.push_to_hub,
|
||||||
|
path=args.input_dir,
|
||||||
|
hub_dir=args.hub_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
write_image_processor(
|
||||||
|
save_dir=args.output_dir,
|
||||||
|
push_to_hub=args.push_to_hub,
|
||||||
|
hub_dir=args.hub_dir,
|
||||||
|
)
|
||||||
|
write_model(
|
||||||
|
model_path=args.output_dir,
|
||||||
|
input_base_path=args.input_dir,
|
||||||
|
push_to_hub=args.push_to_hub,
|
||||||
|
hub_dir=args.hub_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
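For reference, `main()` just chains the three writer helpers above with the CLI defaults. A minimal programmatic sketch of the same flow, assuming the helpers are imported from this conversion script:

```python
# Minimal sketch mirroring main(): convert locally without pushing to the Hub.
# The values below are the script's own CLI defaults.
input_dir = "OpenGVLab/InternVL3-1B"    # original InternVL checkpoint
output_dir = "InternVL3-1B-hf"          # where the converted model and processors are written
hub_dir = "OpenGVLab/InternVL3-1B-hf"   # target Hub repo, only used when push_to_hub=True

write_tokenizer(save_dir=output_dir, push_to_hub=False, path=input_dir, hub_dir=hub_dir)
write_image_processor(save_dir=output_dir, push_to_hub=False, hub_dir=hub_dir)
write_model(model_path=output_dir, input_base_path=input_dir, push_to_hub=False, hub_dir=hub_dir)
```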
1094  src/transformers/models/internvl/modeling_internvl.py (new file; diff suppressed because it is too large)

708  src/transformers/models/internvl/modular_internvl.py (new file)
@@ -0,0 +1,708 @@
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import collections.abc
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||||
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...utils import (
|
||||||
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
can_return_tuple,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
torch_int,
|
||||||
|
)
|
||||||
|
from ..clip.modeling_clip import CLIPMLP
|
||||||
|
from ..janus.modeling_janus import JanusVisionAttention
|
||||||
|
from ..llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
|
||||||
|
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "OpenGVLab/InternVL3-1B-hf"
|
||||||
|
|
||||||
|
_CONFIG_VISION_FOR_DOC = "InternVLVisionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
def eager_attention_forward(
|
||||||
|
module: nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
scaling: float,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
key_states = key
|
||||||
|
value_states = value
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# No upcasting of the attention weights to float32 in this implementation
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionRMSNorm(LlamaRMSNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionAttention(JanusVisionAttention):
|
||||||
|
def __init__(self, config: InternVLVisionConfig):
|
||||||
|
super().__init__()
|
||||||
|
del self.num_key_value_groups
|
||||||
|
|
||||||
|
# Needed for flash attention
|
||||||
|
self.is_causal = False
|
||||||
|
qk_norm = config.use_qk_norm
|
||||||
|
|
||||||
|
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionPreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = InternVLVisionConfig
|
||||||
|
base_model_prefix = "internvl_vision"
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["InternVLVisionLayer"]
|
||||||
|
_supports_sdpa = True
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""Initialize the weights"""
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
elif isinstance(module, InternVLVisionEmbeddings):
|
||||||
|
module.cls_token.data.zero_()
|
||||||
|
if module.mask_token is not None:
|
||||||
|
module.mask_token.data.zero_()
|
||||||
|
if module.position_embeddings is not None:
|
||||||
|
module.position_embeddings.data.zero_()
|
||||||
|
elif isinstance(module, InternVLVisionLayer):
|
||||||
|
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
|
||||||
|
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
|
||||||
|
"""
|
||||||
|
Class for outputs of [`InternVLVisionModel`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
||||||
|
Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
|
||||||
|
*config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
|
||||||
|
will be returned.
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
|
shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionPatchEmbeddings(nn.Module):
|
||||||
|
"""
|
||||||
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
||||||
|
Transformer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
image_size, patch_size = config.image_size, config.patch_size
|
||||||
|
num_channels, hidden_size = config.num_channels, config.hidden_size
|
||||||
|
|
||||||
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||||
|
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.num_patches = num_patches
|
||||||
|
self.patch_shape = patch_shape
|
||||||
|
|
||||||
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
if num_channels != self.num_channels:
|
||||||
|
raise ValueError(
|
||||||
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = self.projection(pixel_values)
|
||||||
|
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
|
||||||
|
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
return embeddings, (patch_height, patch_width)
|
||||||
|
|
||||||
|
|
||||||
|
# Based on timm implementation, which can be found here:
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||||
|
class InternVLVisionEmbeddings(nn.Module):
|
||||||
|
"""
|
||||||
|
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
|
if config.use_mask_token:
|
||||||
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||||
|
else:
|
||||||
|
self.mask_token = None
|
||||||
|
self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
self.image_size = (
|
||||||
|
config.image_size
|
||||||
|
if isinstance(config.image_size, collections.abc.Iterable)
|
||||||
|
else (config.image_size, config.image_size)
|
||||||
|
)
|
||||||
|
num_patches = self.patch_embeddings.num_patches
|
||||||
|
if config.use_absolute_position_embeddings:
|
||||||
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
|
else:
|
||||||
|
self.position_embeddings = None
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||||
|
images. This method is also adapted to support torch.jit tracing.
|
||||||
|
|
||||||
|
Adapted from:
|
||||||
|
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||||
|
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_patches = embeddings.shape[1] - 1
|
||||||
|
num_positions = self.position_embeddings.shape[1] - 1
|
||||||
|
|
||||||
|
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||||
|
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
|
||||||
|
class_pos_embed = self.position_embeddings[:, :1]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
|
||||||
|
new_height = height // self.patch_size
|
||||||
|
new_width = width // self.patch_size
|
||||||
|
|
||||||
|
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||||
|
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
size=(new_height, new_width),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
|
||||||
|
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
_, _, height, width = pixel_values.shape
|
||||||
|
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
|
||||||
|
batch_size, seq_len, _ = embeddings.size()
|
||||||
|
|
||||||
|
if bool_masked_pos is not None:
|
||||||
|
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
||||||
|
# replace the masked visual tokens by mask_tokens
|
||||||
|
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||||
|
embeddings = embeddings * (1 - w) + mask_tokens * w
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||||
|
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||||
|
|
||||||
|
if self.position_embeddings is not None:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
|
||||||
|
return embeddings, (patch_height, patch_width)
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionMLP(CLIPMLP):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionLayer(nn.Module):
|
||||||
|
"""This corresponds to the Block class in the timm implementation."""
|
||||||
|
|
||||||
|
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
self.attention = InternVLVisionAttention(config)
|
||||||
|
self.mlp = InternVLVisionMLP(config)
|
||||||
|
# InternVL uses different layernorm implementations for different models
|
||||||
|
self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
init_values = config.layer_scale_init_value
|
||||||
|
self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
||||||
|
self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
attention_output, attention_weights = self.attention(
|
||||||
|
self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_output = self.lambda_1 * attention_output
|
||||||
|
|
||||||
|
# first residual connection
|
||||||
|
hidden_states = attention_output + hidden_states
|
||||||
|
|
||||||
|
# in InternVLVision, layernorm is also applied after self-attention
|
||||||
|
layer_output = self.layernorm_after(hidden_states)
|
||||||
|
|
||||||
|
layer_output = self.mlp(layer_output)
|
||||||
|
layer_output = self.dropout(layer_output)
|
||||||
|
|
||||||
|
if self.lambda_2 is not None:
|
||||||
|
layer_output = self.lambda_2 * layer_output
|
||||||
|
|
||||||
|
# second residual connection
|
||||||
|
layer_output = layer_output + hidden_states
|
||||||
|
|
||||||
|
return layer_output, attention_weights
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionEncoder(nn.Module):
|
||||||
|
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
) -> Union[tuple, BaseModelOutput]:
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
for i, layer_module in enumerate(self.layer):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
layer_module.__call__, hidden_states, output_attentions
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = layer_module(hidden_states, output_attentions)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
return BaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
||||||
|
|
||||||
|
_CONFIG_VISION_FOR_DOC = "InternVLVisionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
INTERNVL_VISION_START_DOCSTRING = r"""
|
||||||
|
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||||
|
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||||
|
behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config ([`InternVLVisionConfig`]): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
INTERNVL_VISION_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
||||||
|
[`InternVLVisionImageProcessor.__call__`] for details.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare InternVLVision Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
INTERNVL_VISION_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class InternVLVisionModel(InternVLVisionPreTrainedModel):
|
||||||
|
def __init__(self, config: InternVLVisionConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = InternVLVisionEmbeddings(config)
|
||||||
|
self.encoder = InternVLVisionEncoder(config)
|
||||||
|
|
||||||
|
self.layernorm = (
|
||||||
|
nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@add_start_docstrings_to_model_forward(INTERNVL_VISION_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=InternVLVisionModelOutputWithPooling,
|
||||||
|
config_class=_CONFIG_VISION_FOR_DOC,
|
||||||
|
modality="vision",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
|
||||||
|
r"""
|
||||||
|
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
||||||
|
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
sequence_output = self.layernorm(sequence_output)
|
||||||
|
|
||||||
|
return InternVLVisionModelOutputWithPooling(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "InternVLConfig"
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLPreTrainedModel(LlavaPreTrainedModel):
|
||||||
|
def _init_weights(self, module):
|
||||||
|
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
INTERNVL_INPUTS_DOCSTRING = None
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLMultiModalProjector(nn.Module):
|
||||||
|
def __init__(self, config: InternVLConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
|
config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
|
||||||
|
)
|
||||||
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
|
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
|
||||||
|
|
||||||
|
def forward(self, image_features):
|
||||||
|
hidden_states = self.layer_norm(image_features)
|
||||||
|
hidden_states = self.linear_1(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
|
||||||
|
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
|
||||||
|
"""Perform pixel shuffle downsampling on vision features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_features (`torch.Tensor`):
|
||||||
|
Input tensor of shape (batch_size, width, height, channels).
|
||||||
|
scale_factor (`float`, *optional*, defaults to `0.5`):
|
||||||
|
Factor by which to downsample. Default is 0.5, which halves the dimensions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
vision_features (`torch.Tensor`):
|
||||||
|
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
|
||||||
|
"""
|
||||||
|
batch_size, width, height, channels = vision_features.size()
|
||||||
|
|
||||||
|
if height % scale_factor != 0 or width % scale_factor != 0:
|
||||||
|
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
|
||||||
|
|
||||||
|
# Reshape to allow downsampling
|
||||||
|
vision_features = vision_features.view(
|
||||||
|
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
|
||||||
|
)
|
||||||
|
# Permute dimensions to align downsampled axis correctly
|
||||||
|
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
|
||||||
|
|
||||||
|
# Reshape to achieve final downsampled dimensions
|
||||||
|
vision_features = vision_features.view(
|
||||||
|
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Swap height and width back for proper orientation
|
||||||
|
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
|
||||||
|
|
||||||
|
return vision_features
|
||||||
|
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
vision_feature_layer: Union[int, List[int]],
|
||||||
|
vision_feature_select_strategy: str,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Obtains image last hidden states from the vision tower and applies the multimodal projection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
vision_feature_layer (`int` or `List[int]`):
|
||||||
|
Layer index or list of layer indices to extract features from.
|
||||||
|
Returns:
|
||||||
|
vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
||||||
|
"""
|
||||||
|
downsample_ratio = self.config.downsample_ratio
|
||||||
|
if vision_feature_layer == -1:
|
||||||
|
vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||||
|
else:
|
||||||
|
vision_features = self.vision_tower(pixel_values=pixel_values).hidden_states[vision_feature_layer]
|
||||||
|
if vision_feature_select_strategy == "default":
|
||||||
|
vision_features = vision_features[:, 1:, :]
|
||||||
|
|
||||||
|
# Calculate dimensions based on vision features
|
||||||
|
channels = vision_features.shape[1]
|
||||||
|
feature_size = int(channels**0.5)
|
||||||
|
batch_size = vision_features.shape[0]
|
||||||
|
|
||||||
|
# Reshape tensor to spatial dimensions
|
||||||
|
vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
|
||||||
|
|
||||||
|
# Apply downsampling using pixel shuffle
|
||||||
|
vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
|
||||||
|
|
||||||
|
# Reshape tensor to prepare for projection
|
||||||
|
vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
|
||||||
|
|
||||||
|
# Project features through multi-modal projector
|
||||||
|
vision_features = self.multi_modal_projector(vision_features)
|
||||||
|
|
||||||
|
return vision_features
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
vision_feature_layer: Optional[int] = None,
|
||||||
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
|
||||||
|
>>> torch_device = "cuda"
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
|
||||||
|
>>> model = AutoModelForImageTextToText.from_pretrained(
|
||||||
|
... "OpenGVLab/InternVL3-1B-hf", torch_dtype=torch.bfloat16, device_map=torch_device
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {
|
||||||
|
... "type": "image",
|
||||||
|
... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
... },
|
||||||
|
... {
|
||||||
|
... "type": "image",
|
||||||
|
... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||||
|
... },
|
||||||
|
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||||
|
... ],
|
||||||
|
... },
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=200)
|
||||||
|
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
|
||||||
|
The images depict the Statue of Liberty and the Golden Gate Bridge.
|
||||||
|
```"""
|
||||||
|
super().forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
vision_feature_layer=vision_feature_layer,
|
||||||
|
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||||
|
labels=labels,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"InternVLVisionPreTrainedModel",
|
||||||
|
"InternVLVisionModel",
|
||||||
|
"InternVLPreTrainedModel",
|
||||||
|
"InternVLForConditionalGeneration",
|
||||||
|
]
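The `reshape -> pixel_shuffle -> reshape` chain in `get_image_features` is pure shape bookkeeping: with the method's default `scale_factor=0.5`, a square grid of patch features loses half its height and width and gains a 4x larger channel dimension, so a 32x32 grid collapses to the 16x16 = 256 positions that end up as per-image tokens. A self-contained sketch of those reshapes on a dummy tensor (arbitrary channel count, not the model's actual hidden size):

```python
import torch

batch_size, grid, channels = 1, 32, 64  # dummy feature map: (B, W, H, C)
scale = 0.5

x = torch.randn(batch_size, grid, grid, channels)
# Same sequence of view/permute calls as InternVLForConditionalGeneration.pixel_shuffle
x = x.view(batch_size, grid, int(grid * scale), int(channels / scale))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(batch_size, int(grid * scale), int(grid * scale), int(channels / scale**2))
x = x.permute(0, 2, 1, 3).contiguous()

print(x.shape)  # torch.Size([1, 16, 16, 256]): a quarter of the positions, 4x the channels
```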
378  src/transformers/models/internvl/processing_internvl.py (new file)
@@ -0,0 +1,378 @@
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.processing_utils import (
|
||||||
|
AllKwargsForChatTemplate,
|
||||||
|
ImagesKwargs,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
)
|
||||||
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
from ...image_processing_utils import BatchFeature
|
||||||
|
from ...image_utils import (
|
||||||
|
ImageInput,
|
||||||
|
VideoInput,
|
||||||
|
VideoMetadata,
|
||||||
|
concatenate_list,
|
||||||
|
make_batched_videos,
|
||||||
|
make_flat_list_of_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLImagesKwargs(ImagesKwargs, total=False):
|
||||||
|
crop_to_patches: Optional[bool]
|
||||||
|
min_patches: Optional[int]
|
||||||
|
max_patches: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
images_kwargs: InternVLImagesKwargs
|
||||||
|
_defaults = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"padding_side": "left",
|
||||||
|
},
|
||||||
|
"images_kwargs": {
|
||||||
|
"crop_to_patches": True,
|
||||||
|
},
|
||||||
|
"videos_kwargs": {
|
||||||
|
"crop_to_patches": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLProcessor(ProcessorMixin):
|
||||||
|
r"""
|
||||||
|
Constructs an InternVL processor which wraps an [`AutoImageProcessor`] and
|
||||||
|
[`PreTrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
|
||||||
|
tokenizer functionalities. See the [`~InternVLProcessor.__call__`] and [`~InternVLProcessor.decode`] for more information.
|
||||||
|
Args:
|
||||||
|
image_processor ([`AutoImageProcessor`], *optional*):
|
||||||
|
The image processor is a required input.
|
||||||
|
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
|
||||||
|
The tokenizer is a required input.
|
||||||
|
image_seq_length (`int`, *optional*, defaults to 256):
|
||||||
|
The number of image tokens to use per image patch. It should be set so that:
|
||||||
|
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
|
||||||
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
|
in a chat into a tokenizable string.
|
||||||
|
fake_image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||||
|
The token to use for the image placeholder in the text. This token will be replaced by the
|
||||||
|
appropriate image tokens when processing the text with images.
|
||||||
|
fake_video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||||
|
The token to use for the video placeholder in the text. This token will be replaced by the
|
||||||
|
appropriate image tokens when processing the text with videos.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attributes = ["image_processor", "tokenizer"]
|
||||||
|
valid_kwargs = [
|
||||||
|
"chat_template",
|
||||||
|
"image_seq_length",
|
||||||
|
"fake_image_token",
|
||||||
|
"fake_video_token",
|
||||||
|
]
|
||||||
|
image_processor_class = "AutoImageProcessor"
|
||||||
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_processor=None,
|
||||||
|
tokenizer=None,
|
||||||
|
image_seq_length: int = 256,
|
||||||
|
chat_template=None,
|
||||||
|
fake_image_token="<image>",
|
||||||
|
fake_video_token="<video>",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.image_seq_length = image_seq_length
|
||||||
|
self.fake_image_token = fake_image_token
|
||||||
|
self.fake_video_token = fake_video_token
|
||||||
|
self.start_image_token = tokenizer.start_image_token
|
||||||
|
self.end_image_token = tokenizer.end_image_token
|
||||||
|
self.context_image_token = tokenizer.context_image_token
|
||||||
|
|
||||||
|
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||||
|
|
||||||
|
def _insert_media_placeholders(
|
||||||
|
self,
|
||||||
|
text: list[str],
|
||||||
|
image_pixel_values,
|
||||||
|
video_pixel_values,
|
||||||
|
image_num_patches: list[int],
|
||||||
|
video_num_patches: list[int],
|
||||||
|
image_num_patches_indices: np.ndarray,
|
||||||
|
video_num_patches_indices: np.ndarray,
|
||||||
|
video_patch_indices: np.ndarray,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Processes interleaved text with <image> and <video> placeholders, replacing them with appropriate
|
||||||
|
image and video tokens while keeping track of the patches used.
|
||||||
|
"""
|
||||||
|
image_index = 0
|
||||||
|
video_index = 0
|
||||||
|
processed_text = []
|
||||||
|
image_video_patches = []
|
||||||
|
# Support interleaved image and video in prompts:
|
||||||
|
# Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
|
||||||
|
for prompt in text:
|
||||||
|
new_prompt = prompt
|
||||||
|
while self.fake_image_token in new_prompt or self.fake_video_token in new_prompt:
|
||||||
|
if self.fake_image_token in new_prompt and (
|
||||||
|
self.fake_video_token not in new_prompt
|
||||||
|
or new_prompt.index(self.fake_image_token) < new_prompt.index(self.fake_video_token)
|
||||||
|
):
|
||||||
|
# Get the slice of patches corresponding to the current image
|
||||||
|
start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
|
||||||
|
end_index = image_num_patches_indices[image_index]
|
||||||
|
image_video_patches.append(image_pixel_values[start_index:end_index])
|
||||||
|
# Replace the corresponding image placeholder with the correct number of image tokens
|
||||||
|
new_prompt = new_prompt.replace(
|
||||||
|
self.fake_image_token,
|
||||||
|
f"{self.start_image_token}{self.context_image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
else:
|
||||||
|
# Get the slice of patches corresponding to the current video
|
||||||
|
# Here we need to account for both the multiple video frames and the potential multiple patches per frame
|
||||||
|
# As of now, InternVL only supports one patch per frame, but we keep the code flexible for future updates
|
||||||
|
current_patch_index = video_patch_indices[video_index - 1] if video_index > 0 else 0
|
||||||
|
end_patch_index = video_patch_indices[video_index]
|
||||||
|
start_index = video_num_patches_indices[current_patch_index] if video_index > 0 else 0
|
||||||
|
end_index = video_num_patches_indices[end_patch_index - 1]
|
||||||
|
image_video_patches.append(video_pixel_values[start_index:end_index])
|
||||||
|
# Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
|
||||||
|
num_patches = list(video_num_patches[current_patch_index:end_patch_index])
|
||||||
|
video_prompt = "\n".join(
|
||||||
|
f"Frame{i + 1}: {self.start_image_token}{self.context_image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
|
||||||
|
for i in range(len(num_patches))
|
||||||
|
)
|
||||||
|
new_prompt = new_prompt.replace(self.fake_video_token, video_prompt, 1)
|
||||||
|
video_index += 1
|
||||||
|
processed_text.append(new_prompt)
|
||||||
|
|
||||||
|
return processed_text, image_video_patches, image_index, video_index
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
images: Optional[ImageInput] = None,
|
||||||
|
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||||
|
audio=None,
|
||||||
|
videos: Optional[VideoInput] = None,
|
||||||
|
**kwargs: Unpack[InternVLProcessorKwargs],
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if
`text` is not `None`. To prepare the vision inputs, it forwards the `images` and `kwargs` arguments to
GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||||
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
|
tensor. Both channels-first and channels-last formats are supported.
|
||||||
|
text (`str`, `List[str]`, `List[List[str]]`):
|
||||||
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||||
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||||
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||||
|
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||||
|
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
|
||||||
|
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||||
|
If set, will return tensors of a particular framework. Acceptable values are:
|
||||||
|
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||||
|
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||||
|
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||||
|
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||||
|
|
||||||
|
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||||
|
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||||
|
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||||
|
`None`).
|
||||||
|
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||||
|
"""
|
||||||
|
if text is None:
|
||||||
|
raise ValueError("You have to specify text.")
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
InternVLProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(text, (list, tuple)):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
# Process images and videos separately, as videos don't support crop_to_patches
|
||||||
|
image_num_patches = []
|
||||||
|
video_num_patches = []
|
||||||
|
image_videos_inputs = {}
|
||||||
|
image_pixel_values = None
|
||||||
|
video_pixel_values = None
|
||||||
|
image_num_patches_indices = np.array([0])
|
||||||
|
video_patch_indices = np.array([0])
|
||||||
|
video_num_patches_indices = np.array([0])
|
||||||
|
if images is not None:
|
||||||
|
images = make_flat_list_of_images(images)
|
||||||
|
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||||
|
image_num_patches = image_inputs.pop("num_patches")
|
||||||
|
image_pixel_values = image_inputs.pop("pixel_values")
|
||||||
|
image_num_patches_indices = np.cumsum(image_num_patches)
|
||||||
|
if videos is not None:
|
||||||
|
videos = make_batched_videos(videos)
|
||||||
|
num_frames_per_video = [len(video) for video in videos]
|
||||||
|
video_patch_indices = np.cumsum(num_frames_per_video)
|
||||||
|
output_kwargs["images_kwargs"]["crop_to_patches"] = False
|
||||||
|
video_inputs = self.image_processor(images=videos, **output_kwargs["videos_kwargs"])
|
||||||
|
video_num_patches = video_inputs.pop("num_patches")
|
||||||
|
video_pixel_values = video_inputs.pop("pixel_values")
|
||||||
|
video_num_patches_indices = np.cumsum(video_num_patches)
|
||||||
|
|
||||||
|
if images is not None or videos is not None:
|
||||||
|
text, image_video_patches, image_index, video_index = self._insert_media_placeholders(
|
||||||
|
text,
|
||||||
|
image_pixel_values,
|
||||||
|
video_pixel_values,
|
||||||
|
image_num_patches,
|
||||||
|
video_num_patches,
|
||||||
|
image_num_patches_indices,
|
||||||
|
video_num_patches_indices,
|
||||||
|
video_patch_indices,
|
||||||
|
)
|
||||||
|
if images is not None and image_index != len(images):
|
||||||
|
raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
|
||||||
|
if videos is not None and video_index != len(videos):
|
||||||
|
raise ValueError("Number of video placeholders in the prompt does not match the number of videos.")
|
||||||
|
|
||||||
|
# Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
|
||||||
|
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
|
||||||
|
|
||||||
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
|
|
||||||
|
return BatchFeature(data={**text_inputs, **image_videos_inputs})
|
||||||
|
|
||||||
|
def sample_indices_fn(
|
||||||
|
self, metadata: VideoMetadata, num_frames: int = None, initial_shift: Union[bool, float, int] = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The function to generate indices of frames to sample from a video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (`VideoMetadata`):
|
||||||
|
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
|
||||||
|
num_frames (`int`, *optional*):
|
||||||
|
Number of frames to sample uniformly. If None, all frames are sampled.
|
||||||
|
initial_shift (`bool`, `float` or `int`, *optional*, defaults to `True`):
|
||||||
|
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray`: Array of frame indices to sample.
|
||||||
|
"""
|
||||||
|
if initial_shift is True:
|
||||||
|
initial_shift = metadata.total_num_frames / num_frames / 2
|
||||||
|
if num_frames is not None:
|
||||||
|
indices = np.arange(
|
||||||
|
initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames
|
||||||
|
).astype(int)
|
||||||
|
else:
|
||||||
|
indices = np.arange(initial_shift, metadata.total_num_frames).astype(int)
|
||||||
|
|
||||||
|
return indices
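# Example of the sampling above: for a 100-frame video with num_frames=8 and
# initial_shift=True, initial_shift becomes 100 / 8 / 2 = 6.25, and
# np.arange(6.25, 100, 12.5).astype(int) -> [6, 18, 31, 43, 56, 68, 81, 93],
# i.e. one frame taken from roughly the middle of each of the 8 equal segments.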
|
||||||
|
|
||||||
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
refer to the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
|
def decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||||
|
the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
|
return list(tokenizer_input_names) + list(image_processor_input_names)
|
||||||
|
|
||||||
|
# Add model-specific video sampling method when applying the template
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
num_frames: int = 8,
|
||||||
|
initial_shift: Union[bool, float, int] = True,
|
||||||
|
video_load_backend="pyav",
|
||||||
|
**kwargs: Unpack[AllKwargsForChatTemplate],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
|
||||||
|
conversations to turn them into a single tokenizable string.
|
||||||
|
|
||||||
|
The input is expected to be in the following format, where each message content is a list consisting of text and
|
||||||
|
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
|
||||||
|
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
|
||||||
|
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||||
|
{"type": "text", "text": "Please describe this image in detail."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation (`Union[List[Dict[str, str]], List[List[Dict[str, str]]]]`):
|
||||||
|
The conversation to format.
|
||||||
|
chat_template (`Optional[str]`, *optional*):
|
||||||
|
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
||||||
|
chat template is used.
|
||||||
|
num_frames (`int`, *optional*, defaults to 8):
|
||||||
|
Number of frames to sample from a video when using the default `sample_indices_fn`.
|
||||||
|
initial_shift (`bool`, `float` or `int`, *optional*, defaults to `True`):
|
||||||
|
The initial shift to apply when sampling frames using the default `sample_indices_fn`.
|
||||||
|
If `True`, the shift is set so that frames are sampled from the middle of the video.
|
||||||
|
"""
|
||||||
|
sample_indices_fn = kwargs.pop(
|
||||||
|
"sample_indices_fn", partial(self.sample_indices_fn, num_frames=num_frames, initial_shift=initial_shift)
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().apply_chat_template(
|
||||||
|
conversation,
|
||||||
|
chat_template,
|
||||||
|
video_load_backend=video_load_backend,
|
||||||
|
num_frames=num_frames,
|
||||||
|
sample_indices_fn=sample_indices_fn,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["InternVLProcessor"]
@@ -130,6 +130,7 @@ VLM_CLASS_NAMES = [
     "gemma3",
     "mistral3",
     "chameleon",
+    "internvl",
     "qwen2_5_omni",
 ]

0  tests/models/internvl/__init__.py (new file)

894  tests/models/internvl/test_modeling_internvl.py (new file)
@@ -0,0 +1,894 @@
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch InternVL model."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
InternVLConfig,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
cleanup,
|
||||||
|
require_av,
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
InternVLForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLVisionText2TextModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=3,
|
||||||
|
seq_length=7,
|
||||||
|
image_seq_length=64,
|
||||||
|
vision_feature_layer=-1,
|
||||||
|
ignore_index=-100,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=0,
|
||||||
|
pad_token_id=0,
|
||||||
|
image_token_id=1,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=64,
|
||||||
|
model_type="internvl",
|
||||||
|
is_training=True,
|
||||||
|
text_config={
|
||||||
|
"model_type": "qwen2",
|
||||||
|
"vocab_size": 99,
|
||||||
|
"hidden_size": 128,
|
||||||
|
"intermediate_size": 37,
|
||||||
|
"num_hidden_layers": 4,
|
||||||
|
"num_attention_heads": 4,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"output_channels": 64,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"rope_theta": 10000,
|
||||||
|
"mlp_ratio": 4,
|
||||||
|
"tie_word_embeddings": True,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"eos_token_id": 0,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
},
|
||||||
|
vision_config={
|
||||||
|
"hidden_size": 32,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 4,
|
||||||
|
"intermediate_size": 128,
|
||||||
|
"image_size": 64,
|
||||||
|
"patch_size": 4,
|
||||||
|
"num_channels": 3,
|
||||||
|
"hidden_act": "quick_gelu",
|
||||||
|
"use_absolute_position_embeddings": True,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.ignore_index = ignore_index
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.image_token_id = image_token_id
|
||||||
|
self.model_type = model_type
|
||||||
|
self.text_config = text_config
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.vision_feature_layer = vision_feature_layer
|
||||||
|
self.is_training = is_training
|
||||||
|
self.image_seq_length = image_seq_length
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.image_size = image_size
|
||||||
|
self.seq_length = seq_length + image_seq_length
|
||||||
|
|
||||||
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
|
self.vocab_size = text_config["vocab_size"]
|
||||||
|
self.hidden_size = text_config["hidden_size"]
|
||||||
|
self.num_attention_heads = text_config["num_attention_heads"]
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return InternVLConfig(
|
||||||
|
text_config=self.text_config,
|
||||||
|
vision_config=self.vision_config,
|
||||||
|
model_type=self.model_type,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
image_token_id=self.image_token_id,
|
||||||
|
image_seq_length=self.image_seq_length,
|
||||||
|
vision_feature_layer=self.vision_feature_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
config = self.get_config()
|
||||||
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
|
# input_ids[:, -1] = self.pad_token_id
|
||||||
|
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||||
|
input_ids[:, : self.image_seq_length] = self.image_token_id
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||||
|
model = InternVLForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.half()
|
||||||
|
model.eval()
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||||
|
config.torch_dtype = torch.float16
|
||||||
|
model = InternVLForConditionalGeneration(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||||
|
logits = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pixel_values=pixel_values.to(torch.bfloat16),
|
||||||
|
return_dict=True,
|
||||||
|
)["logits"]
|
||||||
|
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
    all_generative_model_classes = (InternVLForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "image-text-to-text": InternVLForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False

    def setUp(self):
        self.model_tester = InternVLVisionText2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=InternVLConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    # overwrite inputs_embeds tests because we need to delete "pixel_values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel_values" for LVLMs,
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            torch.testing.assert_close(out_embeds, out_ids)

    @unittest.skip(reason="Compile is not yet supported in LLaVA-style models")
    def test_sdpa_can_compile_dynamic(self):
        pass

    @unittest.skip("FlashAttention only supports fp16 and bf16 data types")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip("Qwen2 flash attention does not support right padding")
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        pass

@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.small_model_checkpoint = "OpenGVLab/InternVL3-1B-hf"
|
||||||
|
self.medium_model_checkpoint = "OpenGVLab/InternVL3-2B-hf"
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
def test_qwen2_small_model_integration_generate(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_qwen2_small_model_integration_forward(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Forward
|
||||||
|
with torch.inference_mode():
|
||||||
|
output = model(**inputs)
|
||||||
|
|
||||||
|
actual_logits = output.logits[0, -1, :5].cpu()
|
||||||
|
expected_logits = torch.tensor([11.9375, 14.8750, 14.0625, 10.7500, 6.9062], dtype=torch.bfloat16)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(actual_logits, expected_logits, atol=0.1),
|
||||||
|
f"Actual logits: {actual_logits}"
|
||||||
|
f"\nExpected logits: {expected_logits}"
|
||||||
|
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_qwen2_small_model_integration_generate_text_only(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
prompt = "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
inputs = processor(text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "Whispers of dawn,\nSilent whispers of the night,\nNew day's light begins."
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_qwen2_small_model_integration_generate_chat_template(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||||
|
{"type": "text", "text": "Please describe the image explicitly."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||||
|
).to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "The image shows two cats lying on a pink blanket. The cat on the left is a tabby"
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_qwen2_small_model_integration_batched_generate(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
# Prepare inputs
|
||||||
|
prompt = [
|
||||||
|
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
]
|
||||||
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
|
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||||
|
|
||||||
|
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
|
||||||
|
torch_device, dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
# Check first output
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_qwen2_small_model_integration_batched_generate_multi_image(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
# Prepare inputs
|
||||||
|
prompt = [
|
||||||
|
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
"<|im_start|>user\n<image><image>\nWhat are the differences between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
]
|
||||||
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
|
image2 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
image3 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = processor(text=prompt, images=[[image1], [image2, image3]], padding=True, return_tensors="pt").to(
|
||||||
|
torch_device, dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
# Check first output
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||||
|
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\n\nWhat are the differences between these two images?\nassistant\nThe images show the Statue of Liberty and the Golden Gate Bridge from different angles. Here are the differences:\n\n1. **Angle' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_qwen2_medium_model_integration_video(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
|
||||||
|
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.medium_model_checkpoint, quantization_config=quantization_config
|
||||||
|
)
|
||||||
|
# Prepare inputs
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(torch_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||||
|
expected_output = 'The man is performing a forehand shot.' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
def test_qwen2_small_model_integration_interleaved_images_videos(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, torch_dtype=torch.bfloat16, device_map=torch_device
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What are the differences between these two images?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://llava-vl.github.io/static/images/view.jpg",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(torch_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||||
|
expected_output = 'user\n\n\nWhat are the differences between these two images?\nassistant\nThe images depict two distinct scenes:\n\n1. **Left Image**: This shows the Statue of Liberty on Liberty Island, with the' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nA forehand shot' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check third output
|
||||||
|
decoded_output = processor.decode(output[2], skip_special_tokens=True)
|
||||||
|
expected_output = "user\n\nWrite a haiku for this image\nassistant\nSilky lake, \nWooden pier, \nNature's peace." # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.small_model_checkpoint = "OpenGVLab/InternVL2_5-2B-MPO-hf"
|
||||||
|
self.medium_model_checkpoint = "OpenGVLab/InternVL2_5-8B-MPO-hf"
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
def test_llama_small_model_integration_generate(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "The image shows two cats sleeping on a pink couch. They are lying side by side, with their"
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_llama_small_model_integration_forward(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
prompt = "<|im_start|>user\n<image>\nPlease describe the image explicitly.<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Forward
|
||||||
|
with torch.inference_mode():
|
||||||
|
output = model(**inputs)
|
||||||
|
|
||||||
|
actual_logits = output.logits[0, -1, :5].cpu()
|
||||||
|
expected_logits = torch.tensor([-9.8750, -0.4258, 1.4844, -10.3125, -10.3125], dtype=torch.bfloat16)
|
||||||
|
# The original implementation and the transformers implementation do not match exactly, hence the higher tolerance.
|
||||||
|
# The difference is likely due to the different implementations of the attention mechanism (different order of operations)
|
||||||
|
# between the transformers Llama model and the original InternLM model.
|
||||||
|
# The difference has almost no effect on the output tokens, but it does affect the logits a lot more.
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(actual_logits, expected_logits, atol=1),
|
||||||
|
f"Actual logits: {actual_logits}"
|
||||||
|
f"\nExpected logits: {expected_logits}"
|
||||||
|
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_llama_small_model_integration_generate_text_only(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
prompt = "<|im_start|>user\nWrite a haiku<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
inputs = processor(text=prompt, return_tensors="pt").to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "Autumn leaves fall,\nNature's breath, a season's sigh,\nSilent woods awake."
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_llama_small_model_integration_generate_chat_template(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||||
|
{"type": "text", "text": "Please describe the image explicitly."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||||
|
).to(torch_device, dtype=torch.bfloat16)
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
|
decoded_output = processor.decode(
|
||||||
|
generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
expected_output = "The image shows two cats sleeping on a pink couch. They are lying side by side, with their"
|
||||||
|
self.assertEqual(decoded_output, expected_output)
|
||||||
|
|
||||||
|
def test_llama_small_model_integration_batched_generate(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
# Prepare inputs
|
||||||
|
prompt = [
|
||||||
|
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
"<|im_start|>user\n<image>\nDescribe this image<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
]
|
||||||
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
|
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||||
|
|
||||||
|
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
|
||||||
|
torch_device, dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
# Check first output
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese gate in the background, adorned with red and gold colors and Chinese characters' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_llama_small_model_integration_batched_generate_multi_image(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
# Prepare inputs
|
||||||
|
prompt = [
|
||||||
|
"<|im_start|>user\n<image>\nWrite a haiku for this image<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
"<|im_start|>user\n<image><image>\nWhat are the difference between these two images?<|im_end|>\n<|im_start|>assistant\n",
|
||||||
|
]
|
||||||
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
|
image2 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
image3 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = processor(text=prompt, images=[[image1], [image2, image3]], padding=True, return_tensors="pt").to(
|
||||||
|
torch_device, dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
# Check first output
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||||
|
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, still waters.' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. After closely examining the images again, I can see that there are several differences' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_llama_medium_model_integration_video(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.medium_model_checkpoint)
|
||||||
|
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.medium_model_checkpoint, quantization_config=quantization_config
|
||||||
|
)
|
||||||
|
# Prepare inputs
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(torch_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
decoded_output = processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||||
|
expected_output = "The man is performing a forehand shot."
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
def test_llama_small_model_integration_interleaved_images_videos(self):
|
||||||
|
processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
|
||||||
|
model = InternVLForConditionalGeneration.from_pretrained(
|
||||||
|
self.small_model_checkpoint, torch_dtype=torch.bfloat16, device_map=torch_device
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What are the difference between these two images?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://llava-vl.github.io/static/images/view.jpg",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(torch_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||||
|
|
||||||
|
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||||
|
# Batching seems to alter the output slightly, but it is also the case in the original implementation. This seems to be expected: https://github.com/huggingface/transformers/issues/23017#issuecomment-1649630232
|
||||||
|
expected_output = 'user\n\n\nWhat are the difference between these two images?\nassistant\nI apologize for the confusion in my previous response. Upon closer inspection, the differences between the two images are:\n\n1. **' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check second output
|
||||||
|
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\nFrame1: \nFrame2: \nFrame3: \nFrame4: \nFrame5: \nFrame6: \nFrame7: \nFrame8: \nWhat type of shot is the man performing?\nassistant\nThe man is performing a forehand shot. This is a common shot in tennis where the player swings the racket across their' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check third output
|
||||||
|
decoded_output = processor.decode(output[2], skip_special_tokens=True)
|
||||||
|
expected_output = 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nA wooden path leads to the sea,\nPeaceful, untouched dreams.' # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
decoded_output,
|
||||||
|
expected_output,
|
||||||
|
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||||
|
)
|
 327  tests/models/internvl/test_processor_internvl.py  Normal file

@@ -0,0 +1,327 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import shutil
import tempfile
import unittest

from huggingface_hub import hf_hub_download

from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
    import torch


if is_vision_available():
    from transformers import GotOcr2ImageProcessor


@require_vision
class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = InternVLProcessor
    videos_input_name = "pixel_values"

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()

        image_processor = GotOcr2ImageProcessor(
            do_resize=True,
            size={"height": 20, "width": 20},
            max_patches=2,
            do_rescale=True,
            rescale_factor=1 / 255,
            do_normalize=True,
            do_center_crop=True,
            image_mean=[0.485, 0.456, 0.406],
            image_std=[0.229, 0.224, 0.225],
            do_convert_rgb=True,
        )
        tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B-hf", padding_side="left")
        processor_kwargs = cls.prepare_processor_dict()
        processor = InternVLProcessor.from_pretrained(
            "OpenGVLab/InternVL3-1B-hf",
            image_processor=image_processor,
            tokenizer=tokenizer,
            **processor_kwargs,
        )
        processor.save_pretrained(cls.tmpdirname)
        cls.image_token = processor.fake_image_token

    @staticmethod
    def prepare_processor_dict():
        return {"image_seq_length": 10}

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def get_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

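A quick way to sanity-check the fixture above is to load the tiny processor saved by `setUpClass` and run it on a dummy image. The snippet below is an illustrative sketch only, not part of the test suite: `tmpdirname` is a placeholder for `cls.tmpdirname`, and the comment about token expansion describes the expected behaviour given `image_seq_length=10` rather than an asserted value.

```python
# Illustrative sketch: exercise the saved tiny processor on a dummy image (not part of the test suite).
import numpy as np
from transformers import AutoProcessor

tmpdirname = "/tmp/internvl_processor_fixture"  # placeholder for cls.tmpdirname created in setUpClass
processor = AutoProcessor.from_pretrained(tmpdirname)

dummy_image = np.zeros((40, 40, 3), dtype=np.uint8)
out = processor(images=dummy_image, text="<image>\nDescribe this image.", return_tensors="pt")

# With image_seq_length=10, each image crop should expand the <image> placeholder into 10 image tokens.
print(out["input_ids"].shape, out["pixel_values"].shape)
```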
@require_av
|
||||||
|
@require_torch
|
||||||
|
def test_process_interleaved_images_videos(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What are the differences between these two images?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What type of shot is the man performing?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"url": "https://llava-vl.github.io/static/images/view.jpg",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Write a haiku for this image"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs_batched = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process non batched inputs to check if the pixel_values and input_ids are reconstructed in the correct order when batched together
|
||||||
|
images_patches_index = 0
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
message,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
)
|
||||||
|
# We slice with [-inputs["input_ids"].shape[1] :] as the input_ids are left padded
|
||||||
|
torch.testing.assert_close(
|
||||||
|
inputs["input_ids"][0], inputs_batched["input_ids"][i][-inputs["input_ids"].shape[1] :]
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
inputs["pixel_values"],
|
||||||
|
inputs_batched["pixel_values"][
|
||||||
|
images_patches_index : images_patches_index + inputs["pixel_values"].shape[0]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
images_patches_index += inputs["pixel_values"].shape[0]
|
||||||
|
|
||||||
|
# Override video chat_template tests as InternVLProcessor returns flattened video features
|
||||||
|
@require_av
|
||||||
|
def test_apply_chat_template_video_special_processing(self):
|
||||||
|
"""
|
||||||
|
Tests that models can use their own preprocessing to preprocess conversations.
|
||||||
|
"""
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
signature = inspect.signature(processor.__call__)
|
||||||
|
if "videos" not in {*signature.parameters.keys()} or (
|
||||||
|
signature.parameters.get("videos") is not None
|
||||||
|
and signature.parameters["videos"].annotation == inspect._empty
|
||||||
|
):
|
||||||
|
self.skipTest("Processor doesn't accept videos at input")
|
||||||
|
|
||||||
|
video_file_path = hf_hub_download(
|
||||||
|
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video", "path": video_file_path},
|
||||||
|
{"type": "text", "text": "What is shown in this video?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
def _process_messages_for_chat_template(
|
||||||
|
conversation,
|
||||||
|
batch_images,
|
||||||
|
batch_videos,
|
||||||
|
batch_video_metadata,
|
||||||
|
**chat_template_kwargs,
|
||||||
|
):
|
||||||
|
# Let us just always return a dummy prompt
|
||||||
|
new_msg = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video"}, # no need to use path, video is loaded already by this moment
|
||||||
|
{"type": "text", "text": "Dummy prompt for preprocess testing"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
return new_msg
|
||||||
|
|
||||||
|
processor._process_messages_for_chat_template = _process_messages_for_chat_template
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
|
||||||
|
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
|
||||||
|
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
|
||||||
|
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
|
||||||
|
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||||
|
|
||||||
|
def test_apply_chat_template_video_frame_sampling(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
signature = inspect.signature(processor.__call__)
|
||||||
|
if "videos" not in {*signature.parameters.keys()} or (
|
||||||
|
signature.parameters.get("videos") is not None
|
||||||
|
and signature.parameters["videos"].annotation == inspect._empty
|
||||||
|
):
|
||||||
|
self.skipTest("Processor doesn't accept videos at input")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What is shown in this video?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
num_frames = 3
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
num_frames=num_frames,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
|
||||||
|
|
||||||
|
# Load with `video_fps` arg
|
||||||
|
video_fps = 1
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
video_fps=video_fps,
|
||||||
|
num_frames=None, # force to use default num_frames
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), video_fps * 10)
|
||||||
|
|
||||||
|
# Load with `video_fps` and `num_frames` args, should raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
video_fps=video_fps,
|
||||||
|
num_frames=num_frames,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load without any arg should use the default loading method
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||||
|
|
||||||
|
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
|
||||||
|
# because we assume they come from one video
|
||||||
|
messages[0][0]["content"][0] = {
|
||||||
|
"type": "video",
|
||||||
|
"url": [
|
||||||
|
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||||
|
"https://www.ilankelman.org/stopsigns/australia.jpg",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 2)
|
@@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
     "Llama4VisionModel",  # Building part of bigger (tested) model. # TODO: add tests
     "Emu3VQVAE",  # Building part of bigger (tested) model
     "Emu3TextModel",  # Building part of bigger (tested) model
+    "InternVLVisionModel",  # Building part of bigger (tested) model
     "JanusVisionModel",  # Building part of bigger (tested) model
     "TimesFmModel",  # Building part of bigger (tested) model
 ]