mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Llava Onevision: add model (#32673)
* working version * fix copies * update * tests * update docs * codestyle * add more tests * add returns for docs * clean up * Update src/transformers/models/llava_onevision/processing_llava_onevision.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * updates * codestyle * style * shouldn't be reversed * [run-slow] llava_onevision * [run-slow] llava_onevision * add pooling in videos * [run-slow] llava_onevision * num-logits-to-keep * [run-slow] llava_onevision * [run-slow] llava_onevision * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * video matched orig impl * fix tests * chat template was modified * Update docs/source/en/model_doc/llava_onevision.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add morer info in the doc page --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
9230d78e76
commit
43df47d8e7
@ -836,6 +836,8 @@
|
||||
title: LLaVA-NeXT
|
||||
- local: model_doc/llava_next_video
|
||||
title: LLaVa-NeXT-Video
|
||||
- local: model_doc/llava_onevision
|
||||
title: LLaVA-Onevision
|
||||
- local: model_doc/lxmert
|
||||
title: LXMERT
|
||||
- local: model_doc/matcha
|
||||
|
@ -189,6 +189,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [LLaVa](model_doc/llava) | ✅ | ❌ | ❌ |
|
||||
| [LLaVA-NeXT](model_doc/llava_next) | ✅ | ❌ | ❌ |
|
||||
| [LLaVa-NeXT-Video](model_doc/llava_next_video) | ✅ | ❌ | ❌ |
|
||||
| [LLaVA-Onevision](model_doc/llava_onevision) | ✅ | ❌ | ❌ |
|
||||
| [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
|
||||
| [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ |
|
||||
| [LUKE](model_doc/luke) | ✅ | ❌ | ❌ |
|
||||
|
319
docs/source/en/model_doc/llava_onevision.md
Normal file
319
docs/source/en/model_doc/llava_onevision.md
Normal file
@ -0,0 +1,319 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# LLaVA-Onevision
|
||||
|
||||
## Overview
|
||||
|
||||
The LLaVA-Onevision model was proposed in [LLaVA-OneVision: Easy Visual Task Transfer](https://arxiv.org/abs/2408.03326) by <Bo Li, Yuanhan Zhang, Dong Guo, Renrui Zhang, Feng Li, Hao Zhang, Kaichen Zhang, Yanwei Li, Ziwei Liu, Chunyuan Li
|
||||
|
||||
LLaVA-Onevision is a Vision-Language Model that can generate text conditioned on one or several images/videos. The model consists of SigLIP vision encoder and a Qwen2 language backbone. The images are processed with anyres-9 technique where the image is split into 9 patches to better process high resolution images and capture as much details as possible. However, videos are pooled to a total sequence length of 196 tokens each frame for more memory efficient computation. LLaVA-Onevision is available in three sizes: 0.5B, 7B and 72B and achieves remarkable performance on benchmark evaluations.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We present LLaVA-OneVision, a family of open large multimodal models (LMMs)
|
||||
developed by consolidating our insights into data, models, and visual representations in the LLaVA-NeXT blog series. Our experimental results demonstrate that
|
||||
LLaVA-OneVision is the first single model that can simultaneously push the performance boundaries of open LMMs in three important computer vision scenarios:
|
||||
single-image, multi-image, and video scenarios. Importantly, the design of LLaVAOneVision allows strong transfer learning across different modalities/scenarios,
|
||||
yielding new emerging capabilities. In particular, strong video understanding and
|
||||
cross-scenario capabilities are demonstrated through task transfer from images to
|
||||
videos.*
|
||||
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llava-ov-acrhitecture.png"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> LLaVA=Onevision architecture. Taken from the <a href="https://arxiv.org/abs/2408.03326">original paper.</a> </small>
|
||||
|
||||
Tips:
|
||||
|
||||
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
- Llava-Onevision uses different number of patches for images and thus has to pad the inputs inside modeling code, aside from the padding done when processing the inputs. The default setting is "left-padding" if model is in `eval()` mode, otherwise "right-padding".
|
||||
|
||||
</Tip>
|
||||
|
||||
- Note that the model should use a specific prompt format, on which the large language model (LLM) was trained. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history, passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities.
|
||||
|
||||
We will use [llava-onevision-qwen2-7b-si-hf](https://huggingface.co/llava-hf/llava-onevision-qwen2-7b-si-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-si-hf")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What’s shown in this image?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "This image shows a red stop sign."},]
|
||||
},
|
||||
{
|
||||
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe the image in more details."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
|
||||
print(text_prompt)
|
||||
>>> "<|im_start|>user\n<image>What is shown in this image?<|im_end|>\n<|im_start|>assistant\nPage showing the list of options.<|im_end|>"
|
||||
```
|
||||
|
||||
This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
|
||||
The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main).
|
||||
|
||||
|
||||
## Usage example
|
||||
|
||||
### Single image inference
|
||||
|
||||
Here's how to load the model and perform inference in half-precision (`torch.float16`):
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
model.to("cuda:0")
|
||||
|
||||
# prepare image and text prompt, using the appropriate prompt template
|
||||
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0", torch.float16)
|
||||
|
||||
# autoregressively complete prompt
|
||||
output = model.generate(**inputs, max_new_tokens=100)
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
'user\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart, also known as a spider chart or a star chart, which is used to compare multiple quantitative variables. Each axis represents a different variable, and the chart is filled with'
|
||||
```
|
||||
|
||||
### Multi image inference
|
||||
|
||||
LLaVa-Onevision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
|
||||
|
||||
```python
|
||||
import requests
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
||||
|
||||
# Load the model in half-precision
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
|
||||
# Get three different images
|
||||
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
image_stop = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image_cats = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
image_snowman = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
|
||||
conversation_1 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "There is a red stop sign in the image."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What about this image? How many cats do you see?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
conversation_2 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
|
||||
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
|
||||
prompts = [prompt_1, prompt_2]
|
||||
|
||||
# We can simply feed images in the order they have to be used in the text prompt
|
||||
inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)
|
||||
|
||||
# Generate
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=30)
|
||||
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
['user\n\nWhat is shown in this image?\nassistant\nThere is a red stop sign in the image.\nuser\n\nWhat about this image? How many cats do you see?\nassistant\ntwo', 'user\n\nWhat is shown in this image?\nassistant\n']
|
||||
```
|
||||
|
||||
### Video inference
|
||||
|
||||
LLaVa-Onevision also can perform inference with videos as input, where video frames are treated as multiple images. Here is how you can do it:
|
||||
|
||||
```python
|
||||
import av
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
||||
|
||||
# Load the model in half-precision
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
|
||||
|
||||
def read_video_pyav(container, indices):
|
||||
'''
|
||||
Decode the video with PyAV decoder.
|
||||
Args:
|
||||
container (`av.container.input.InputContainer`): PyAV container.
|
||||
indices (`List[int]`): List of frame indices to decode.
|
||||
Returns:
|
||||
result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
'''
|
||||
frames = []
|
||||
container.seek(0)
|
||||
start_index = indices[0]
|
||||
end_index = indices[-1]
|
||||
for i, frame in enumerate(container.decode(video=0)):
|
||||
if i > end_index:
|
||||
break
|
||||
if i >= start_index and i in indices:
|
||||
frames.append(frame)
|
||||
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
|
||||
# Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos, up to 32 frames)
|
||||
video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
|
||||
container = av.open(video_path)
|
||||
total_frames = container.streams.video[0].frames
|
||||
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
|
||||
video = read_video_pyav(container, indices)
|
||||
|
||||
# For videos we have to feed a "video" type instead of "image"
|
||||
conversation = [
|
||||
{
|
||||
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "Why is this video funny?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
inputs = processor(videos=list(video), text=prompt, return_tensors="pt").to("cuda:0", torch.float16)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=60)
|
||||
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
["user\n\nWhy is this video funny?\nassistant\nThe video appears to be humorous because it shows a young child, who is wearing glasses and holding a book, seemingly reading with a serious and focused expression. The child's glasses are a bit oversized for their face, which adds a comical touch, as it's a common trope to see children wearing"]
|
||||
```
|
||||
|
||||
## Model optimization
|
||||
|
||||
### Quantization using Bitsandbytes
|
||||
|
||||
The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with:
|
||||
|
||||
```python
|
||||
from transformers import LlavaOnevisionForConditionalGeneration, BitsAndBytesConfig
|
||||
|
||||
# specify how to quantize the model
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
)
|
||||
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
|
||||
```
|
||||
|
||||
### Use Flash-Attention 2 to further speed-up generation
|
||||
|
||||
First make sure to install flash-attn. Refer to the [original repository of Flash Attention](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with:
|
||||
|
||||
```python
|
||||
from transformers import LlavaOnevisionForConditionalGeneration
|
||||
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
use_flash_attention_2=True
|
||||
).to(0)
|
||||
```
|
||||
|
||||
|
||||
## LlavaOnevisionConfig
|
||||
|
||||
[[autodoc]] LlavaOnevisionConfig
|
||||
|
||||
## LlavaOnevisionProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionProcessor
|
||||
|
||||
## LlavaOnevisionImageProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionImageProcessor
|
||||
|
||||
## LlavaOnevisionVideoProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionVideoProcessor
|
||||
|
||||
## LlavaOnevisionForConditionalGeneration
|
||||
|
||||
[[autodoc]] LlavaOnevisionForConditionalGeneration
|
||||
- forward
|
@ -60,6 +60,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
|
||||
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
|
||||
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
|
||||
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
|
||||
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
|
||||
@ -226,6 +227,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
|
||||
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
|
||||
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
|
||||
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
|
||||
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
|
||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
|
||||
|
@ -533,6 +533,7 @@ _import_structure = {
|
||||
"LlavaNextVideoConfig",
|
||||
"LlavaNextVideoProcessor",
|
||||
],
|
||||
"models.llava_onevision": ["LlavaOnevisionConfig", "LlavaOnevisionProcessor"],
|
||||
"models.longformer": [
|
||||
"LongformerConfig",
|
||||
"LongformerTokenizer",
|
||||
@ -1183,6 +1184,9 @@ else:
|
||||
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
|
||||
_import_structure["models.llava_next"].append("LlavaNextImageProcessor")
|
||||
_import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor")
|
||||
_import_structure["models.llava_onevision"].extend(
|
||||
["LlavaOnevisionImageProcessor", "LlavaOnevisionVideoProcessor"]
|
||||
)
|
||||
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
|
||||
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
|
||||
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
|
||||
@ -2532,6 +2536,12 @@ else:
|
||||
"LlavaNextVideoPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.llava_onevision"].extend(
|
||||
[
|
||||
"LlavaOnevisionForConditionalGeneration",
|
||||
"LlavaOnevisionPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.longformer"].extend(
|
||||
[
|
||||
"LongformerForMaskedLM",
|
||||
@ -5308,6 +5318,10 @@ if TYPE_CHECKING:
|
||||
LlavaNextVideoConfig,
|
||||
LlavaNextVideoProcessor,
|
||||
)
|
||||
from .models.llava_onevision import (
|
||||
LlavaOnevisionConfig,
|
||||
LlavaOnevisionProcessor,
|
||||
)
|
||||
from .models.longformer import (
|
||||
LongformerConfig,
|
||||
LongformerTokenizer,
|
||||
@ -5993,6 +6007,7 @@ if TYPE_CHECKING:
|
||||
from .models.levit import LevitFeatureExtractor, LevitImageProcessor
|
||||
from .models.llava_next import LlavaNextImageProcessor
|
||||
from .models.llava_next_video import LlavaNextVideoImageProcessor
|
||||
from .models.llava_onevision import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor
|
||||
from .models.mask2former import Mask2FormerImageProcessor
|
||||
from .models.maskformer import (
|
||||
MaskFormerFeatureExtractor,
|
||||
@ -7113,6 +7128,10 @@ if TYPE_CHECKING:
|
||||
LlavaNextVideoForConditionalGeneration,
|
||||
LlavaNextVideoPreTrainedModel,
|
||||
)
|
||||
from .models.llava_onevision import (
|
||||
LlavaOnevisionForConditionalGeneration,
|
||||
LlavaOnevisionPreTrainedModel,
|
||||
)
|
||||
from .models.longformer import (
|
||||
LongformerForMaskedLM,
|
||||
LongformerForMultipleChoice,
|
||||
|
@ -1030,6 +1030,7 @@ class StaticCache(Cache):
|
||||
|
||||
self.batch_size = batch_size or max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
|
||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
self.head_dim = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
|
@ -1450,8 +1450,8 @@ class GenerationMixin:
|
||||
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||
|
||||
cache_kwargs = {
|
||||
"config": self.config,
|
||||
"batch_size": batch_size,
|
||||
"config": self.config if hasattr(self.config, "text_config") else self.config,
|
||||
"max_batch_size": batch_size,
|
||||
"max_cache_len": max_cache_len,
|
||||
"device": device,
|
||||
"dtype": cache_dtype,
|
||||
@ -2353,7 +2353,11 @@ class GenerationMixin:
|
||||
this_peer_finished = False
|
||||
|
||||
# prepare layers for DoLa decoding
|
||||
final_layer = self.config.num_hidden_layers
|
||||
final_layer = (
|
||||
self.config.text_config.num_hidden_layers
|
||||
if hasattr(self.config, "text_config")
|
||||
else self.config.num_hidden_layers
|
||||
)
|
||||
# if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
|
||||
# as the early exit from word embeddings will become identity function
|
||||
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
|
||||
|
@ -132,6 +132,7 @@ from . import (
|
||||
llava,
|
||||
llava_next,
|
||||
llava_next_video,
|
||||
llava_onevision,
|
||||
longformer,
|
||||
longt5,
|
||||
luke,
|
||||
|
@ -149,6 +149,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("llava", "LlavaConfig"),
|
||||
("llava_next", "LlavaNextConfig"),
|
||||
("llava_next_video", "LlavaNextVideoConfig"),
|
||||
("llava_onevision", "LlavaOnevisionConfig"),
|
||||
("longformer", "LongformerConfig"),
|
||||
("longt5", "LongT5Config"),
|
||||
("luke", "LukeConfig"),
|
||||
@ -444,6 +445,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("llava", "LLaVa"),
|
||||
("llava_next", "LLaVA-NeXT"),
|
||||
("llava_next_video", "LLaVa-NeXT-Video"),
|
||||
("llava_onevision", "LLaVA-Onevision"),
|
||||
("longformer", "Longformer"),
|
||||
("longt5", "LongT5"),
|
||||
("luke", "LUKE"),
|
||||
|
@ -99,6 +99,7 @@ else:
|
||||
("llava", ("CLIPImageProcessor",)),
|
||||
("llava_next", ("LlavaNextImageProcessor",)),
|
||||
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
|
||||
("llava_onevision", ("LlavaOnevisionImageProcessor",)),
|
||||
("mask2former", ("Mask2FormerImageProcessor",)),
|
||||
("maskformer", ("MaskFormerImageProcessor",)),
|
||||
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
|
@ -314,6 +314,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("llava", "LlavaForConditionalGeneration"),
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
("longformer", "LongformerForMaskedLM"),
|
||||
("luke", "LukeForMaskedLM"),
|
||||
("lxmert", "LxmertForPreTraining"),
|
||||
@ -729,6 +730,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
("llava", "LlavaForConditionalGeneration"),
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
||||
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
|
||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||
("pix2struct", "Pix2StructForConditionalGeneration"),
|
||||
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
|
||||
|
@ -73,6 +73,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("llava", "LlavaProcessor"),
|
||||
("llava_next", "LlavaNextProcessor"),
|
||||
("llava_next_video", "LlavaNextVideoProcessor"),
|
||||
("llava_onevision", "LlavaOnevisionProcessor"),
|
||||
("markuplm", "MarkupLMProcessor"),
|
||||
("mctct", "MCTCTProcessor"),
|
||||
("mgp-str", "MgpstrProcessor"),
|
||||
|
@ -257,6 +257,7 @@ else:
|
||||
),
|
||||
),
|
||||
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava-onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
72
src/transformers/models/llava_onevision/__init__.py
Normal file
72
src/transformers/models/llava_onevision/__init__.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_llava_onevision": ["LlavaOnevisionConfig"],
|
||||
"processing_llava_onevision": ["LlavaOnevisionProcessor"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_llava_onevision"] = ["LlavaOnevisionImageProcessor"]
|
||||
|
||||
_import_structure["video_processing_llava_onevision"] = ["LlavaOnevisionVideoProcessor"]
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_llava_onevision"] = [
|
||||
"LlavaOnevisionForConditionalGeneration",
|
||||
"LlavaOnevisionPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_llava_onevision import LlavaOnevisionConfig
|
||||
from .processing_llava_onevision import LlavaOnevisionProcessor
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_llava_onevision import LlavaOnevisionImageProcessor
|
||||
from .video_processing_llava_onevision import LlavaOnevisionVideoProcessor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_llava_onevision import (
|
||||
LlavaOnevisionForConditionalGeneration,
|
||||
LlavaOnevisionPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
@ -0,0 +1,183 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaOnevisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LlavaOnevisionForConditionalGeneration`]. It is used to instantiate an
|
||||
Llava-NeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the [llava-hf/llava-onevision-qwen2-7b-ov-hf](https://huggingface.co/llava-hf/llava-onevision-qwen2-7b-ov-hf)
|
||||
model.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SiglipVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
|
||||
The config object or dictionary of the text backbone.
|
||||
image_token_index (`int`, *optional*, defaults to 151646):
|
||||
The image token index to encode the image prompt.
|
||||
video_token_index (`int`, *optional*, defaults to 151647):
|
||||
The video token index to encode the video prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
||||
If `"full"`, the full vision features are used.
|
||||
vision_feature_layer (`int`, *optional*, defaults to -1):
|
||||
The index of the layer to select the vision feature.
|
||||
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
|
||||
Aspect ratio used when processong image features. The default value is "anyres_max_9".
|
||||
image_grid_pinpoints (`List`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionConfig, SiglipVisionConfig, Qwen2Config
|
||||
|
||||
>>> # Initializing a CLIP-vision config
|
||||
>>> vision_config = SiglipVisionConfig()
|
||||
|
||||
>>> # Initializing a Llama config
|
||||
>>> text_config = Qwen2Config()
|
||||
|
||||
>>> # Initializing a Llava-Next llava-hf/llava-onevision-qwen2-7b-ov-hf style configuration
|
||||
>>> configuration = LlavaOnevisionConfig(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the llava-hf/llava-onevision-qwen2-7b-ov-hf style configuration
|
||||
>>> model = LlavaOnevisionForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "llava_onevision"
|
||||
is_composition = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
image_token_index=151646,
|
||||
video_token_index=151647,
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_select_strategy="full",
|
||||
vision_feature_layer=-1,
|
||||
vision_aspect_ratio="anyres_max_9",
|
||||
image_grid_pinpoints=None,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_token_index = image_token_index
|
||||
self.video_token_index = video_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
"vision_feature_select_strategy should be one of 'default', 'full'."
|
||||
f"Got: {vision_feature_select_strategy}"
|
||||
)
|
||||
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.vision_aspect_ratio = vision_aspect_ratio
|
||||
image_grid_pinpoints = (
|
||||
image_grid_pinpoints
|
||||
if image_grid_pinpoints is not None
|
||||
else [
|
||||
[384, 384],
|
||||
[384, 768],
|
||||
[384, 1152],
|
||||
[384, 1536],
|
||||
[384, 1920],
|
||||
[384, 2304],
|
||||
[768, 384],
|
||||
[768, 768],
|
||||
[768, 1152],
|
||||
[768, 1536],
|
||||
[768, 1920],
|
||||
[768, 2304],
|
||||
[1152, 384],
|
||||
[1152, 768],
|
||||
[1152, 1152],
|
||||
[1152, 1536],
|
||||
[1152, 1920],
|
||||
[1152, 2304],
|
||||
[1536, 384],
|
||||
[1536, 768],
|
||||
[1536, 1152],
|
||||
[1536, 1536],
|
||||
[1536, 1920],
|
||||
[1536, 2304],
|
||||
[1920, 384],
|
||||
[1920, 768],
|
||||
[1920, 1152],
|
||||
[1920, 1536],
|
||||
[1920, 1920],
|
||||
[1920, 2304],
|
||||
[2304, 384],
|
||||
[2304, 768],
|
||||
[2304, 1152],
|
||||
[2304, 1536],
|
||||
[2304, 1920],
|
||||
[2304, 2304],
|
||||
]
|
||||
)
|
||||
self.image_grid_pinpoints = image_grid_pinpoints
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config["model_type"] = (
|
||||
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
|
||||
)
|
||||
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
||||
elif vision_config is None:
|
||||
vision_config = CONFIG_MAPPING["siglip_vision_model"](
|
||||
hidden_size=1152,
|
||||
intermediate_size=4304,
|
||||
patch_size=14,
|
||||
image_size=384,
|
||||
num_hidden_layers=26,
|
||||
num_attention_heads=14,
|
||||
vision_use_head=False,
|
||||
)
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
|
||||
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
elif text_config is None:
|
||||
text_config = CONFIG_MAPPING["qwen2"]()
|
||||
|
||||
self.text_config = text_config
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
@ -0,0 +1,360 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert LLaVa-Onevision checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from PIL import Image
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
LlavaOnevisionConfig,
|
||||
LlavaOnevisionForConditionalGeneration,
|
||||
LlavaOnevisionImageProcessor,
|
||||
LlavaOnevisionProcessor,
|
||||
LlavaOnevisionVideoProcessor,
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"model.vision_tower.": "",
|
||||
"model.mm_projector": "multi_modal_projector",
|
||||
"model": "model.model",
|
||||
"vision_model.model": "vision_model",
|
||||
"lm_head": "language_model.lm_head",
|
||||
"model.model": "language_model.model",
|
||||
"multi_modal_projector.0": "multi_modal_projector.linear_1",
|
||||
"multi_modal_projector.2": "multi_modal_projector.linear_2",
|
||||
"language_model.model.image_newline": "image_newline",
|
||||
}
|
||||
|
||||
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
|
||||
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# tied wieghts so lm.head is not saved. Let's clone to load state dict
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.endswith(".inv_freq"):
|
||||
continue
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
new_state_dict[key] = value.to(torch.float16)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_image():
|
||||
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
# load original config
|
||||
filepath = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model")
|
||||
# read json
|
||||
with open(filepath) as f:
|
||||
data = json.load(f)
|
||||
print(data)
|
||||
|
||||
if model_id in ["lmms-lab/llava-onevision-qwen2-0.5b-ov", "lmms-lab/llava-onevision-qwen2-0.5b-si"]:
|
||||
text_model_id = "Qwen/Qwen2-0.5B-Instruct"
|
||||
elif model_id in ["lmms-lab/llava-onevision-qwen2-7b-ov", "lmms-lab/llava-onevision-qwen2-7b-si"]:
|
||||
text_model_id = "Qwen/Qwen2-7B-Instruct"
|
||||
elif model_id in ["lmms-lab/llava-onevision-qwen2-72b-ov", "lmms-lab/llava-onevision-qwen2-72b-si"]:
|
||||
text_model_id = "Qwen/Qwen2-72B-Instruct"
|
||||
|
||||
vision_model_id = data["mm_vision_tower"]
|
||||
torch.set_default_dtype(torch.float16)
|
||||
text_config = AutoConfig.from_pretrained(text_model_id)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=True)
|
||||
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_tokens(AddedToken("<video>", special=True, normalized=False), special_tokens=True)
|
||||
|
||||
image_processor = LlavaOnevisionImageProcessor.from_pretrained(vision_model_id)
|
||||
video_processor = LlavaOnevisionVideoProcessor.from_pretrained(vision_model_id)
|
||||
processor = LlavaOnevisionProcessor(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
image_processor=image_processor,
|
||||
num_image_tokens=729,
|
||||
vision_feature_select_strategy="full",
|
||||
chat_template=chat_template,
|
||||
)
|
||||
|
||||
vision_config = SiglipVisionConfig(
|
||||
hidden_size=1152,
|
||||
image_size=384,
|
||||
intermediate_size=4304,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=26, # drop the last layer
|
||||
patch_size=14,
|
||||
vision_use_head=False, # no head
|
||||
).to_dict()
|
||||
|
||||
config = LlavaOnevisionConfig(
|
||||
text_config=text_config.to_dict(),
|
||||
vision_config=vision_config,
|
||||
use_image_newline_parameter=True,
|
||||
)
|
||||
|
||||
with init_empty_weights():
|
||||
model = LlavaOnevisionForConditionalGeneration(config)
|
||||
|
||||
# load original state dict
|
||||
state_dict = load_original_state_dict(model_id)
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
model.eval()
|
||||
|
||||
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
|
||||
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
|
||||
n = pre_expansion_embeddings.size()[0]
|
||||
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
|
||||
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
|
||||
|
||||
# We add an image token so we resize the model
|
||||
# Pad to 64 for performance reasons
|
||||
# Qwen-based models have extra unused space in the vocab size already, so no need to resize
|
||||
pad_shape = 64
|
||||
vocab_size = config.text_config.vocab_size
|
||||
num_tokens = vocab_size + 2
|
||||
model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
|
||||
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
|
||||
tuple(
|
||||
(dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
|
||||
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
# Load everything back for inference tests in float32 because prev script was written as that
|
||||
# Though it's mostly loaded in fp16 as original weights are in fp16
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
pytorch_dump_folder_path, torch_dtype="float16", device_map="auto"
|
||||
)
|
||||
processor = LlavaOnevisionProcessor.from_pretrained(pytorch_dump_folder_path)
|
||||
device = model.device
|
||||
|
||||
# prepare inputs
|
||||
image = load_image()
|
||||
prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch.float16)
|
||||
|
||||
# verify inputs
|
||||
filepath = hf_hub_download(
|
||||
repo_id="RaushanTurganbay/test-image", filename="llava_onevision_pixel_values.pt", repo_type="dataset"
|
||||
)
|
||||
original_pixel_values = torch.load(filepath, map_location="cpu")
|
||||
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
|
||||
|
||||
image_sizes = torch.tensor([[899, 1024]])
|
||||
assert image_sizes[0].tolist() == inputs.image_sizes[0].tolist()
|
||||
|
||||
# verify single forward pass
|
||||
print("Single forward pass")
|
||||
with torch.inference_mode():
|
||||
inputs = inputs.to(device)
|
||||
outputs = model(**inputs)
|
||||
print("Shape of logits:", outputs.logits.shape)
|
||||
print("First values of logits:", outputs.logits[0, :3, :3])
|
||||
|
||||
if model_id == "lmms-lab/llava-onevision-qwen2-0.5b-si":
|
||||
# Not yet checked against reference
|
||||
expected_slice = torch.tensor(
|
||||
[[-12.1953, -14.6797, -12.7891], [0.5840, -0.8467, 1.3799], [3.6055, 4.5430, 9.9062]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-0.5b-ov":
|
||||
# Not yet checked against reference
|
||||
expected_slice = torch.tensor(
|
||||
[[-12.0234, -14.3828, -12.7500], [2.3594, 1.0000, 3.9336], [3.6582, 4.7148, 9.1172]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-7b-si":
|
||||
# Not yet checked against reference
|
||||
expected_slice = torch.tensor(
|
||||
[[1.7656, 3.3418, 1.4033], [0.0757, 0.7427, 3.5098], [6.7109, 5.6797, 9.3828]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-7b-ov":
|
||||
# Not yet checked against reference
|
||||
expected_slice = torch.tensor(
|
||||
[[1.8496, 3.4219, 1.3135], [3.0996, 3.0117, 3.1484], [4.2422, 4.7109, 9.9688]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-72b-si":
|
||||
# Not yet checked against reference
|
||||
expected_slice = torch.tensor(
|
||||
[[4.1875, 4.4883, 2.7910], [1.2949, 5.1328, 3.1582], [0.9390, 6.4531, 8.4375]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-72b-ov":
|
||||
# Not yet checked against reference
|
||||
expected_slice = torch.tensor(
|
||||
[[4.2930, 4.7305, 2.7363], [1.7529, 5.0742, 3.9590], [1.3936, 6.3438, 9.3984]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {model_id} not supported")
|
||||
|
||||
assert torch.allclose(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
|
||||
print("Logits are ok!")
|
||||
|
||||
# verify generation
|
||||
output_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=100,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
generated_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
||||
|
||||
print("Generated text:", repr(generated_text))
|
||||
|
||||
if model_id == "lmms-lab/llava-onevision-qwen2-0.5b-si":
|
||||
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image is a radar chart that shows the performance of different algorithms or models in a specific domain, such as image classification or natural language processing. The chart is color-coded to represent different algorithms, with each color corresponding to a specific algorithm. The algorithms are labeled as BLIP-2, InstructBLIP, Owen-VL-Chat, and LLaVA-1.5. The chart also includes a legend at the bottom that explains the color coding and the algorithms represented."
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-0.5b-ov":
|
||||
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into different categories, each represented by a different color and labeled with the name of the model or technique used. The models are evaluated based on their performance metrics, such as BLEU-2, InstructBLIP, Qwen-VL-Chat, and LLaVA-1.5. The radar chart helps to visualize the relative"
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-7b-si":
|
||||
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThis image is a radar chart that compares the performance of different models on various metrics. The models being compared are BLIP-2, InstructBLIP, and Qwen-VL-Chat. The metrics being compared are VQA, QA, GQA, VQA-av2, and VQA-av2. The chart shows that BLIP-2 performs the best on all metrics, followed by InstructBLIP and Qwen-VL-Chat."
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-7b-ov":
|
||||
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart, also known as a spider chart or a star chart, which is used to compare multiple quantitative variables. Each axis represents a different variable, and the chart is filled with data points that represent the performance or values of different entities across these variables.\n\nIn this particular radar chart, the variables are represented on the axes, and the performance of different models or systems is shown by the lines connecting the data points. The models or systems are labeled along the bottom of the chart,"
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-72b-si":
|
||||
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. The chart is used to compare the performance of different models or systems across various benchmarks or metrics.\n\nIn this specific radar chart, there are multiple axes, each representing a different benchmark or metric, such as VQA2, GQA, TextVQA, and others. The chart includes several colored lines"
|
||||
elif model_id == "lmms-lab/llava-onevision-qwen2-72b-ov":
|
||||
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image is a radar chart comparing the performance of different models on various multimodal benchmarks. The models compared are BLIP-2, InstructBLIP, POPE, QWen-VL-Chat, and LLava-1.5. The benchmarks include VQAv2, GQA, TextVQA, SQA-IMG, VizWiz, MM-IMDb, MM-VQA, MM-IMDb-CN, MM-IMDb-EN, MM-"
|
||||
else:
|
||||
raise ValueError(f"Model {model_id} not supported")
|
||||
|
||||
assert generated_text == expected_text
|
||||
print("Generated text is ok!")
|
||||
|
||||
# verify batched generation
|
||||
print("Batched generation...")
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
images=[image, cats_image],
|
||||
text=[prompt, prompt],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(device, torch.float16)
|
||||
|
||||
for k, v in inputs.items():
|
||||
print(k, v.shape)
|
||||
|
||||
print("Image sizes:", inputs.image_sizes)
|
||||
|
||||
# make sure image_sizes are the same
|
||||
# as otherwise batched generation doesn't work
|
||||
inputs.image_sizes[1] = inputs.image_sizes[0]
|
||||
|
||||
print("Batched generation...")
|
||||
output_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=20,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
print(outputs)
|
||||
|
||||
if push_to_hub:
|
||||
checkpoint_name = model_id.split("/")[-1]
|
||||
print(f"Pushing to repo llava-hf/{checkpoint_name}-hf")
|
||||
model.push_to_hub(f"llava-hf/{checkpoint_name}-hf")
|
||||
processor.push_to_hub(f"llava-hf/{checkpoint_name}-hf")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
help="Hub location of the model to convert",
|
||||
default="lmms-lab/llava-onevision-qwen2-0.5b-ov",
|
||||
choices=[
|
||||
"lmms-lab/llava-onevision-qwen2-0.5b-ov",
|
||||
"lmms-lab/llava-onevision-qwen2-0.5b-si",
|
||||
"lmms-lab/llava-onevision-qwen2-7b-si",
|
||||
"lmms-lab/llava-onevision-qwen2-7b-ov",
|
||||
"lmms-lab/llava-onevision-qwen2-72b-si",
|
||||
"lmms-lab/llava-onevision-qwen2-72b-ov",
|
||||
],
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_llava_to_hf(args.model_id, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -0,0 +1,711 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for LLaVa-Onevision."""
|
||||
|
||||
import math
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
|
||||
from ...image_transforms import (
|
||||
PaddingMode,
|
||||
convert_to_rgb,
|
||||
pad,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_vision_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images
|
||||
def make_batched_images(images) -> List[List[ImageInput]]:
|
||||
"""
|
||||
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
||||
|
||||
Args:
|
||||
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
||||
The input image.
|
||||
|
||||
Returns:
|
||||
list: A list of images.
|
||||
"""
|
||||
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
||||
return [img for img_list in images for img in img_list]
|
||||
|
||||
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
||||
return images
|
||||
|
||||
elif is_valid_image(images):
|
||||
return [images]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {images}")
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.divide_to_patches
|
||||
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
|
||||
"""
|
||||
Divides an image into patches of a specified size.
|
||||
|
||||
Args:
|
||||
image (`np.array`):
|
||||
The input image.
|
||||
patch_size (`int`):
|
||||
The size of each patch.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
list: A list of np.array representing the patches.
|
||||
"""
|
||||
patches = []
|
||||
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||
for i in range(0, height, patch_size):
|
||||
for j in range(0, width, patch_size):
|
||||
if input_data_format == ChannelDimension.LAST:
|
||||
patch = image[i : i + patch_size, j : j + patch_size]
|
||||
else:
|
||||
patch = image[:, i : i + patch_size, j : j + patch_size]
|
||||
patches.append(patch)
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.expand_to_square
|
||||
def expand_to_square(image: np.array, background_color, input_data_format) -> np.array:
|
||||
"""
|
||||
Expands an image to a square by adding a background color.
|
||||
"""
|
||||
|
||||
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||
if width == height:
|
||||
return image
|
||||
elif width > height:
|
||||
result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
|
||||
result[(width - height) // 2 : (width - height) // 2 + height, :] = image
|
||||
return result
|
||||
else:
|
||||
result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
|
||||
result[:, (height - width) // 2 : (height - width) // 2 + width] = image
|
||||
return result
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next._get_patch_output_size
|
||||
def _get_patch_output_size(image, target_resolution, input_data_format):
|
||||
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
|
||||
target_height, target_width = target_resolution
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.ceil(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.ceil(original_width * scale_h), target_width)
|
||||
|
||||
return new_height, new_width
|
||||
|
||||
|
||||
class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a LLaVa-Onevisino-Video video processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
||||
method.
|
||||
image_grid_pinpoints (`List` *optional*, defaults to `[[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]`):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processinf videos.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
image_grid_pinpoints: List = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = True,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 384, "width": 384}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
image_grid_pinpoints = (
|
||||
image_grid_pinpoints
|
||||
if image_grid_pinpoints is not None
|
||||
else [
|
||||
[384, 384],
|
||||
[384, 768],
|
||||
[384, 1152],
|
||||
[384, 1536],
|
||||
[384, 1920],
|
||||
[384, 2304],
|
||||
[768, 384],
|
||||
[768, 768],
|
||||
[768, 1152],
|
||||
[768, 1536],
|
||||
[768, 1920],
|
||||
[768, 2304],
|
||||
[1152, 384],
|
||||
[1152, 768],
|
||||
[1152, 1152],
|
||||
[1152, 1536],
|
||||
[1152, 1920],
|
||||
[1152, 2304],
|
||||
[1536, 384],
|
||||
[1536, 768],
|
||||
[1536, 1152],
|
||||
[1536, 1536],
|
||||
[1536, 1920],
|
||||
[1536, 2304],
|
||||
[1920, 384],
|
||||
[1920, 768],
|
||||
[1920, 1152],
|
||||
[1920, 1536],
|
||||
[1920, 1920],
|
||||
[1920, 2304],
|
||||
[2304, 384],
|
||||
[2304, 768],
|
||||
[2304, 1152],
|
||||
[2304, 1536],
|
||||
[2304, 1920],
|
||||
[2304, 2304],
|
||||
]
|
||||
)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.image_grid_pinpoints = image_grid_pinpoints
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_pad = do_pad
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor.pad
|
||||
def pad(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
|
||||
mode: PaddingMode = PaddingMode.CONSTANT,
|
||||
constant_values: Union[float, Iterable[float]] = 0.0,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
|
||||
dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
|
||||
as input.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
The image to pad.
|
||||
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
|
||||
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
||||
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
||||
- `((before, after),)` yields same before and after pad for height and width.
|
||||
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
||||
mode (`PaddingMode`):
|
||||
The padding mode to use. Can be one of:
|
||||
- `"constant"`: pads with a constant value.
|
||||
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
||||
vector along each axis.
|
||||
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
||||
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
||||
constant_values (`float` or `Iterable[float]`, *optional*):
|
||||
The value to use for the padding if `mode` is `"constant"`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded image.
|
||||
|
||||
"""
|
||||
|
||||
# call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim
|
||||
if isinstance(padding, int) or len(padding) != 4:
|
||||
return pad(image, padding, mode, constant_values, data_format, input_data_format)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
if mode == PaddingMode.CONSTANT:
|
||||
image = np.pad(image, padding, mode="constant", constant_values=constant_values)
|
||||
elif mode == PaddingMode.REFLECT:
|
||||
image = np.pad(image, padding, mode="reflect")
|
||||
elif mode == PaddingMode.REPLICATE:
|
||||
image = np.pad(image, padding, mode="edge")
|
||||
elif mode == PaddingMode.SYMMETRIC:
|
||||
image = np.pad(image, padding, mode="symmetric")
|
||||
else:
|
||||
raise ValueError(f"Invalid padding mode: {mode}")
|
||||
image = (
|
||||
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
||||
)
|
||||
return image
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._resize_for_patching
|
||||
def _resize_for_patching(
|
||||
self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension
|
||||
) -> np.array:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image (np.array):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
resample (`PILImageResampling`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
np.array: The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
|
||||
|
||||
return resized_image
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
|
||||
def _pad_for_patching(
|
||||
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> np.array:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
|
||||
|
||||
return padded_image
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor.get_image_patches
|
||||
def get_image_patches(
|
||||
self,
|
||||
image: np.array,
|
||||
grid_pinpoints,
|
||||
size: tuple,
|
||||
patch_size: int,
|
||||
resample: PILImageResampling,
|
||||
data_format: ChannelDimension,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Process an image with variable resolutions by dividing it into patches.
|
||||
|
||||
Args:
|
||||
image (np.array):
|
||||
The input image to be processed.
|
||||
grid_pinpoints (List):
|
||||
A string representation of a list of possible resolutions.
|
||||
size (`tuple`):
|
||||
Size to resize the original image to.
|
||||
patch_size (`int`):
|
||||
Size of the patches to divide the image into.
|
||||
resample (`PILImageResampling`):
|
||||
Resampling filter to use if resizing the image.
|
||||
data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format for the output image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
List[np.array]: A list of NumPy arrays containing the processed image patches.
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise TypeError("grid_pinpoints must be a list of possible resolutions.")
|
||||
|
||||
possible_resolutions = grid_pinpoints
|
||||
|
||||
image_size = get_image_size(image, channel_dim=input_data_format)
|
||||
best_resolution = select_best_resolution(image_size, possible_resolutions)
|
||||
resized_image = self._resize_for_patching(
|
||||
image, best_resolution, resample=resample, input_data_format=input_data_format
|
||||
)
|
||||
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
|
||||
|
||||
patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
|
||||
|
||||
# make sure that all patches are in the input data format
|
||||
patches = [
|
||||
to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
|
||||
for patch in patches
|
||||
]
|
||||
|
||||
resized_original_image = resize(
|
||||
image,
|
||||
size=size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
image_patches = [resized_original_image] + patches
|
||||
|
||||
return image_patches
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_batching
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: List[np.ndarray],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[np.ndarray]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
|
||||
Returns:
|
||||
List[`np.ndarray`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
self.pad(
|
||||
image,
|
||||
padding=((0, max_patch - image.shape[0]), (0, 0), (0, 0), (0, 0)),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Batch of frames (one video) to preprocess. Expects a batch of frames with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
if do_resize:
|
||||
images = [
|
||||
resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
image_grid_pinpoints: List = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
image_grid_pinpoints (`List` *optional*, defaults to `self.image_grid_pinpoints`):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is
|
||||
selected based on the original size of the image.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else self.image_grid_pinpoints
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
images = make_batched_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
new_images = []
|
||||
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
|
||||
for image in images:
|
||||
# convert image into a list of patches
|
||||
# we intentially use the same data format as the input data format
|
||||
size_tuple = (
|
||||
(size["height"], size["width"])
|
||||
if "height" in size and "width" in size
|
||||
else (size["shortest_edge"], size["shortest_edge"])
|
||||
)
|
||||
image_patches = self.get_image_patches(
|
||||
image,
|
||||
image_grid_pinpoints,
|
||||
size=size_tuple,
|
||||
patch_size=size["height"],
|
||||
resample=resample,
|
||||
data_format=input_data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
# preprocess patches
|
||||
pixel_values = self._preprocess(
|
||||
image_patches,
|
||||
do_resize=do_resize,
|
||||
size=size_tuple,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
pixel_values = np.array(pixel_values)
|
||||
new_images.append(pixel_values)
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(new_images)
|
||||
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
|
||||
)
|
@ -0,0 +1,727 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Llava-Onevision model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
logging,
|
||||
)
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_llava_onevision import LlavaOnevisionConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LlavaNextConfig"
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.get_anyres_image_grid_shape
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise TypeError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
|
||||
)
|
||||
image_size = image_size.tolist()
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.image_size_to_num_patches
|
||||
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
|
||||
"""
|
||||
Calculate the number of patches after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
|
||||
The size of the input image in the format (height, width). ?
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
int: the number of patches
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise TypeError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
|
||||
raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
|
||||
image_size = image_size.tolist()
|
||||
|
||||
best_resolution = select_best_resolution(image_size, grid_pinpoints)
|
||||
height, width = best_resolution
|
||||
num_patches = 0
|
||||
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
|
||||
for i in range(0, height, patch_size):
|
||||
for j in range(0, width, patch_size):
|
||||
num_patches += 1
|
||||
# add the base patch
|
||||
num_patches += 1
|
||||
return num_patches
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.unpad_image
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaOnevision
|
||||
class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for LlavaOnevision causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
|
||||
sequence_length, hidden_size)`.
|
||||
|
||||
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision
|
||||
class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaOnevisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
LLAVA_ONEVISION_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaVA-Onevision Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAVA_ONEVISION_START_DOCSTRING,
|
||||
)
|
||||
class LlavaOnevisionPreTrainedModel(PreTrainedModel):
|
||||
config_class = LlavaOnevisionConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlavaOnevisionVisionAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of LlavaNext isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
LLAVA_ONEVISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
|
||||
[`LlavaNextImageProcessor`] for processing images.
|
||||
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
|
||||
The sizes of the images in the batch, being (height, width) for each image.
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, frames, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input videos. Pixel values can be obtained using
|
||||
[`LlavaNextVideoProcessor`]. See [`LlavaNextVideoProcessor.__call__`] for details. [`LlavaProcessor`] uses
|
||||
[`LlavaNextVideoProcessor`] for processing videos.
|
||||
image_sizes_videos (`torch.LongTensor` of shape `(batch_size, frames, 2)`, *optional*):
|
||||
The sizes of the videos in the batch, being (height, width) for each frame in the video.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
vision_feature_layer (`int`, *optional*, defaults to -2):
|
||||
The index of the layer to select the vision feature.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
||||
If `"full"`, the full vision features are used.
|
||||
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
|
||||
Aspect ratio used when processong image features. The default value is "anyres_max_9".
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The LLaVA-Onevision model which consists of a vision backbone and a language model.""",
|
||||
LLAVA_ONEVISION_START_DOCSTRING,
|
||||
)
|
||||
class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel):
|
||||
def __init__(self, config: LlavaOnevisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(
|
||||
config.vision_config, attn_implementation=config._attn_implementation
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
|
||||
embed_std = 1 / math.sqrt(config.text_config.hidden_size)
|
||||
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.language_model = AutoModelForCausalLM.from_config(
|
||||
config.text_config, attn_implementation=config._attn_implementation
|
||||
)
|
||||
self.post_init()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.tie_weights
|
||||
def tie_weights(self):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
|
||||
"""
|
||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||
|
||||
Args:
|
||||
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
|
||||
List of image feature tensor, each contains all the visual feature of all patches.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||||
New line embedding vector.
|
||||
vision_aspect_ratio (`str`, *optional*, "anyres_max_9"):
|
||||
Aspect ratio used when processong image features. The default value is "anyres_max_9".
|
||||
Returns:
|
||||
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
|
||||
feature_lens (`List[int]`)
|
||||
token length of each image in image_features
|
||||
"""
|
||||
new_image_features = []
|
||||
feature_lens = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError("The number of patches is not consistent with the image size.")
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
|
||||
channels, curr_height, curr_width = image_feature.shape
|
||||
ratio = math.sqrt(curr_height * curr_width / (max_num_patches * height**2))
|
||||
if ratio > 1.1:
|
||||
image_feature = image_feature[None]
|
||||
image_feature = nn.functional.interpolate(
|
||||
image_feature, [int(curr_height // ratio), int(curr_width // ratio)], mode="bilinear"
|
||||
)[0]
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
image_newline[:, None, None]
|
||||
.expand(*image_feature.shape[:-1], 1)
|
||||
.to(image_feature.device, image_feature.dtype),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
|
||||
return image_features, feature_lens
|
||||
|
||||
def apply_pooling(self, image_features):
|
||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
batch_frames, seq_len, dim = image_features.shape
|
||||
image_features = image_features.view(batch_frames, height, width, -1)
|
||||
image_features = image_features.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
height, weight = image_features.shape[2:]
|
||||
scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)]
|
||||
image_features = nn.functional.interpolate(image_features, size=scaled_shape, mode="bilinear")
|
||||
|
||||
image_features = image_features.permute(0, 2, 3, 1)
|
||||
image_features = image_features.view(batch_frames, -1, dim)
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
pixel_values_videos: torch.FloatTensor = None,
|
||||
image_sizes_videos: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
num_logits_to_keep (`int`, *optional*):
|
||||
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
|
||||
|
||||
Returns:
|
||||
[`~LlavaOnevisionCausalLMOutputWithPast`] (if `return_dict=True`) or a `tuple`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> import torch
|
||||
>>> from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration
|
||||
|
||||
>>> model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype="float16", device_map="cuda:0")
|
||||
>>> processor = LlavaOnevisionProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
|
||||
>>> conversation = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "text", "text": "What is shown in this image?"},
|
||||
... {"type": "image"},
|
||||
... ],
|
||||
... },
|
||||
... ]
|
||||
>>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
>>> image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
>>> inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
>>> processor.batch_decode(output, skip_special_tokens=True)[0]
|
||||
"user\n\nWhat is shown in this image?\nassistant\ncat"
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
vision_aspect_ratio = (
|
||||
vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values/pixel_values_videos and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# Images are processed with Anyres
|
||||
if pixel_values is not None:
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
]
|
||||
|
||||
# unpad extra patches and concatenate them
|
||||
if pixel_values.dim() == 5:
|
||||
_pixel_values_list = [
|
||||
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
|
||||
]
|
||||
# [batch_size*frames*num_patches, num_channels, height, width] where frames=1 for images
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
|
||||
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
# Video are simply embedded and further pooled to decrease seq len
|
||||
if pixel_values_videos is not None:
|
||||
batch_size, frames, channels, height, width = pixel_values_videos.shape
|
||||
pixel_values_videos = pixel_values_videos.view(batch_size * frames, channels, height, width)
|
||||
video_features = self.vision_tower(pixel_values_videos, output_hidden_states=True)
|
||||
selected_video_feature = video_features.hidden_states[vision_feature_layer]
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_video_feature = selected_video_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_video_feature = selected_video_feature
|
||||
video_features = self.multi_modal_projector(selected_video_feature)
|
||||
|
||||
video_features = self.apply_pooling(video_features)
|
||||
video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
|
||||
image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
special_video_mask = (
|
||||
(input_ids == self.config.video_token_index)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
shift_attention_mask = attention_mask[..., 1:]
|
||||
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return LlavaOnevisionCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
image_sizes=None,
|
||||
pixel_values_videos=None,
|
||||
image_sizes_videos=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
model_inputs["image_sizes_videos"] = image_sizes_videos
|
||||
|
||||
return model_inputs
|
@ -0,0 +1,274 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for LLaVa-Onevision.
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Unpack
|
||||
else:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import (
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"image_kwargs": {},
|
||||
"video_kwargs": {},
|
||||
}
|
||||
|
||||
|
||||
class LlavaOnevisionProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a LLaVa-Onevision processor which wraps a LLaVa-Onevision video processor, LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.
|
||||
|
||||
[`LlavaNextProcessor`] offers all the functionalities of [`LlavaOnevisionVideoProcessor`], [`LlavaNextImageProcessor`] and [`LlamaTokenizerFast`]. See the
|
||||
[`~LlavaOnevisionVideoProcessor.__call__`], [`~LlavaNextProcessor.__call__`] and [`~LlavaNextProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaNextImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
video_processor ([`LlavaOnevisionVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one imagethat will be returned by vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Shoudl be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer", "video_processor"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
video_processor_class = "LlavaOnevisionVideoProcessor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
video_processor=None,
|
||||
num_image_tokens=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<image>",
|
||||
video_token="<video>",
|
||||
**kwargs: Unpack[LlavaOnevisionProcessorKwargs],
|
||||
):
|
||||
self.num_image_tokens = num_image_tokens
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
LlavaOnevisionProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
image_inputs = video_inputs = {}
|
||||
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
image_sizes = iter(image_inputs["image_sizes"])
|
||||
height, width = get_image_size(
|
||||
to_numpy_array(image_inputs["pixel_values"][0][0]),
|
||||
channel_dim=output_kwargs["images_kwargs"].get("data_format"),
|
||||
)
|
||||
text = self._expand_image_tokens(text, image_sizes, height, width, self.image_token)
|
||||
|
||||
if videos is not None:
|
||||
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
|
||||
|
||||
one_video = to_numpy_array(video_inputs["pixel_values_videos"][0])
|
||||
height, width = get_image_size(one_video[0], channel_dim=output_kwargs["images_kwargs"].get("data_format"))
|
||||
num_frames = one_video.shape[0] # frame dim is always after batch dim
|
||||
patches_height_width = int(math.sqrt(self.num_image_tokens))
|
||||
pooled_height_width = math.ceil(patches_height_width / 2)
|
||||
num_video_tokens = (num_frames * pooled_height_width * pooled_height_width) + 1 # +1 for newline token
|
||||
text = [sample.replace(self.video_token, self.video_token * num_video_tokens) for sample in text]
|
||||
|
||||
# Padding side can be in TextKwargs but is not accepted by the tokenizer
|
||||
_ = output_kwargs["text_kwargs"].pop("padding_side", None)
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def _expand_image_tokens(
|
||||
self,
|
||||
text: List[TextInput],
|
||||
image_sizes: Iterable[Union[List[int], int]],
|
||||
height: int,
|
||||
width: int,
|
||||
special_token: str,
|
||||
num_frames: int = 1,
|
||||
):
|
||||
prompt_strings = []
|
||||
for sample in text:
|
||||
while special_token in sample:
|
||||
image_size_list = next(image_sizes)
|
||||
orig_height, orig_width = image_size_list[0] if num_frames != 1 else image_size_list
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens * num_frames, 1)
|
||||
prompt_strings.append(sample)
|
||||
text = [sample.replace("<placeholder>", special_token) for sample in prompt_strings]
|
||||
return text
|
||||
|
||||
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
|
||||
image_grid_pinpoints = self.image_processor.image_grid_pinpoints
|
||||
|
||||
height_best_resolution, width_best_resolution = select_best_resolution(
|
||||
[orig_height, orig_width], image_grid_pinpoints
|
||||
)
|
||||
scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
|
||||
|
||||
patches_height = patches_width = int(math.sqrt(self.num_image_tokens))
|
||||
unpadded_features, newline_features = self._get_unpadded_features(
|
||||
orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
|
||||
)
|
||||
|
||||
# The base patch covers the entire image (no CLS for SigLIP)
|
||||
base_features = self.num_image_tokens
|
||||
num_image_tokens = unpadded_features + newline_features + base_features
|
||||
return num_image_tokens
|
||||
|
||||
def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
|
||||
"""
|
||||
Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
|
||||
because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
|
||||
patches an image is divided into and get the number of features from that.
|
||||
"""
|
||||
current_height = patches_height * scale_height
|
||||
current_width = patches_width * scale_width
|
||||
|
||||
original_aspect_ratio = width / height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
new_height = int(height * (current_width / width))
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= padding * 2
|
||||
else:
|
||||
new_width = int(width * (current_height / height))
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= padding * 2
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
|
||||
ratio = math.sqrt(current_height * current_width / (9 * patches_height**2))
|
||||
if ratio > 1.1:
|
||||
unpadded_features = int(current_height // ratio) * int(current_width // ratio)
|
||||
newline_features = int(current_height // ratio)
|
||||
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
@ -0,0 +1,335 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Video processor class for LLaVa-Onevision."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
convert_to_rgb,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_vision_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def make_batched_videos(videos) -> List[VideoInput]:
|
||||
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
||||
return videos
|
||||
|
||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
||||
if isinstance(videos[0], Image.Image) or len(videos[0].shape) == 3:
|
||||
return [videos]
|
||||
elif len(videos[0].shape) == 4:
|
||||
return [list(video) for video in videos]
|
||||
|
||||
elif is_valid_image(videos) and len(videos.shape) == 4:
|
||||
return [list(videos)]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {videos}")
|
||||
|
||||
|
||||
class LlavaOnevisionVideoProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a LLaVa-Onevisino-Video video processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 384, "width": 384}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Batch of frames (one video) to preprocess. Expects a batch of frames with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled videos. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
videos = make_batched_videos(videos)
|
||||
|
||||
if not valid_images(videos[0]):
|
||||
raise ValueError(
|
||||
"Invalid video type. Must be a list consisting of PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
size_tuple = (
|
||||
(size["height"], size["width"])
|
||||
if "height" in size and "width" in size
|
||||
else (size["shortest_edge"], size["shortest_edge"])
|
||||
)
|
||||
|
||||
pixel_values = [
|
||||
self._preprocess(
|
||||
video,
|
||||
do_resize=do_resize,
|
||||
size=size_tuple,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for video in videos
|
||||
]
|
||||
|
||||
return BatchFeature(
|
||||
data={"pixel_values_videos": pixel_values},
|
||||
tensor_type=return_tensors,
|
||||
)
|
@ -390,6 +390,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
del output["image_processor"]
|
||||
if "feature_extractor" in output:
|
||||
del output["feature_extractor"]
|
||||
if "chat_template" in output:
|
||||
del output["chat_template"]
|
||||
|
||||
# Some attributes have different names but containing objects that are not simple strings
|
||||
output = {
|
||||
|
@ -5374,6 +5374,20 @@ class LlavaNextVideoPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LlavaOnevisionForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LlavaOnevisionPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LongformerForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -373,6 +373,20 @@ class LlavaNextVideoImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LlavaOnevisionImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LlavaOnevisionVideoProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class Mask2FormerImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
@ -491,6 +491,7 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
@ -630,6 +631,7 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
@ -986,6 +988,7 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
@ -1152,6 +1155,7 @@ class GenerationTesterMixin:
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
@ -1216,6 +1220,7 @@ class GenerationTesterMixin:
|
||||
output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
@ -1453,8 +1458,10 @@ class GenerationTesterMixin:
|
||||
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# With left-padding (length 32)
|
||||
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
||||
pad_token_id = config.pad_token_id if getattr(config, "pad_token_id") is not None else 0
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
||||
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
||||
@ -1765,15 +1772,14 @@ class GenerationTesterMixin:
|
||||
}
|
||||
|
||||
max_cache_len = seq_length + max_new_tokens
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
head_dim = (
|
||||
model.config.head_dim
|
||||
if hasattr(model.config, "head_dim")
|
||||
else model.config.hidden_size // model.config.num_attention_heads
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
num_key_value_heads = (
|
||||
model.config.num_attention_heads
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else model.config.num_key_value_heads
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@ -1922,6 +1928,7 @@ class GenerationTesterMixin:
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
gen_len = (
|
||||
|
0
tests/models/llava_onevision/__init__.py
Normal file
0
tests/models/llava_onevision/__init__.py
Normal file
@ -0,0 +1,291 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor
|
||||
|
||||
|
||||
class LlavaOnevisionImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=20,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"height": 20, "width": 20}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# Copied from tests.models.llava_next_video.test_image_processing_llava_next_video.LlavaNextVideoProcessingTester.prepare_video_inputs
|
||||
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(8, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * 8)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = LlavaOnevisionImageProcessor if is_vision_available() else None
|
||||
video_processing_class = LlavaOnevisionVideoProcessor if is_vision_available() else None
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaOnevision
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = LlavaOnevisionImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
self.assertTrue(hasattr(image_processing, "image_grid_pinpoints"))
|
||||
|
||||
def test_video_processor_properties(self):
|
||||
image_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
@unittest.skip(
|
||||
reason="LlavaOnevisionImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
|
||||
) # FIXME raushan
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
def test_nested_input(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
|
||||
# Test batched as a list of images
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched as a nested list of images, where each sublist is one batch
|
||||
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
|
||||
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 1522, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
|
||||
|
||||
# Image processor should return same pixel values, independently of input format
|
||||
self.assertTrue((encoded_images_nested == encoded_images).all())
|
||||
|
||||
def test_call_pil_video(self):
|
||||
# Initialize image_processing
|
||||
video_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (7, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_numpy_video(self):
|
||||
# Initialize image_processing
|
||||
video_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (7, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch_video(self):
|
||||
# Initialize image_processing
|
||||
video_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, torchify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (7, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
523
tests/models/llava_onevision/test_modeling_llava_onevision.py
Normal file
523
tests/models/llava_onevision/test_modeling_llava_onevision.py
Normal file
@ -0,0 +1,523 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Llava-NeXT model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
LlavaOnevisionConfig,
|
||||
LlavaOnevisionForConditionalGeneration,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class LlavaOnevisionVisionText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
ignore_index=-100,
|
||||
image_token_index=0,
|
||||
projector_hidden_act="gelu",
|
||||
seq_length=7,
|
||||
vision_feature_select_strategy="full",
|
||||
vision_feature_layer=-1,
|
||||
text_config={
|
||||
"model_type": "qwen2",
|
||||
"seq_length": 7,
|
||||
"is_training": True,
|
||||
"use_input_mask": True,
|
||||
"use_token_type_ids": False,
|
||||
"use_labels": True,
|
||||
"vocab_size": 99,
|
||||
"hidden_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"max_position_embeddings": 580,
|
||||
"type_vocab_size": 16,
|
||||
"type_sequence_label_size": 2,
|
||||
"initializer_range": 0.02,
|
||||
"num_labels": 3,
|
||||
"num_choices": 4,
|
||||
"pad_token_id": 0,
|
||||
},
|
||||
is_training=True,
|
||||
vision_config={
|
||||
"image_size": 16,
|
||||
"patch_size": 2,
|
||||
"num_channels": 3,
|
||||
"is_training": True,
|
||||
"hidden_size": 32,
|
||||
"projection_dim": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"dropout": 0.1,
|
||||
"attention_dropout": 0.1,
|
||||
"initializer_range": 0.02,
|
||||
},
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.seq_length = seq_length
|
||||
|
||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||
self.vocab_size = text_config["vocab_size"]
|
||||
self.hidden_size = text_config["hidden_size"]
|
||||
self.num_attention_heads = text_config["num_attention_heads"]
|
||||
self.is_training = is_training
|
||||
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 30
|
||||
self.encoder_seq_length = 7
|
||||
self.image_grid_pinpoints = [[32, 32]]
|
||||
|
||||
def get_config(self):
|
||||
return LlavaOnevisionConfig(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
ignore_index=self.ignore_index,
|
||||
image_token_index=self.image_token_index,
|
||||
projector_hidden_act=self.projector_hidden_act,
|
||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||
vision_feature_layer=self.vision_feature_layer,
|
||||
image_grid_pinpoints=self.image_grid_pinpoints,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
self.batch_size,
|
||||
9,
|
||||
self.vision_config["num_channels"],
|
||||
self.vision_config["image_size"],
|
||||
self.vision_config["image_size"],
|
||||
]
|
||||
)
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
|
||||
# we are giving 3 images let's make sure we pass in 3 image tokens
|
||||
input_ids[:, 1] = config.image_token_index
|
||||
labels = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
||||
# maskout where the image token is
|
||||
labels[:, 1] == self.ignore_index
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"image_sizes": torch.tensor(
|
||||
[[self.vision_config["image_size"], self.vision_config["image_size"]]] * self.batch_size
|
||||
),
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_llava_onevision_model_fp16_forward(
|
||||
self, config, input_ids, pixel_values, attention_mask, image_sizes
|
||||
):
|
||||
model = LlavaOnevisionForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
def create_and_check_llava_onevision_model_fp16_autocast_forward(
|
||||
self, config, input_ids, pixel_values, attention_mask, image_sizes
|
||||
):
|
||||
config.torch_dtype = torch.float16
|
||||
model = LlavaOnevisionForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
image_sizes=image_sizes,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `LlavaOnevisionForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LlavaOnevisionVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LlavaOnevisionConfig, has_text_modality=False)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
# LLaVa Onevision has SigLIP backbone which init weights differently from CLIP
|
||||
if "image_newline" in name or "vision_tower" in name:
|
||||
continue
|
||||
elif param.requires_grad:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, SiglipVisionModel does not support standalone training"
|
||||
)
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, SiglipVisionModel does not support standalone training"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, SiglipVisionModel does not support standalone training"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("VLMs can't do assisted decoding yet!")
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", padding_side="left"
|
||||
)
|
||||
image_file = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/images_test", filename="llava_v1_5_radar.jpg", repo_type="dataset"
|
||||
)
|
||||
video_file = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
||||
)
|
||||
self.image = Image.open(image_file)
|
||||
self.video = np.load(video_file)
|
||||
self.prompt_image = "user\n<image>\nWhat do you see in this image?<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.prompt_video = "user\n<video>\nWhat do you see in this video?<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test(self):
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", torch_dtype="float16", device_map=torch_device
|
||||
)
|
||||
|
||||
inputs = self.processor(images=self.image, text=self.prompt_image, return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
self.assertTrue(inputs.input_ids.shape[1] == 6567) # should expand num-image-tokens times
|
||||
self.assertTrue(inputs.pixel_values.shape == torch.Size([1, 10, 3, 384, 384]))
|
||||
self.assertTrue(inputs.image_sizes.tolist() == [[899, 1024]])
|
||||
|
||||
# verify single forward pass
|
||||
inputs = inputs.to(torch_device)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-12.3125, -14.5625, -12.8750], [3.4023, 5.0508, 9.5469], [3.5762, 4.4922, 7.8906]],
|
||||
dtype=torch.float32,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output.logits[0, :3, :3], expected_slice, atol=1e-3))
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=100)
|
||||
EXPECTED_DECODED_TEXT = 'user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related to natural language processing or machine learning. The chart is divided into several axes, each representing a different model or method. The models are color-coded and labeled with their respective names. The axes are labeled with terms such as "VQA," "GQA," "MQA," "VIZ," "TextVQA," "SQA-IMG," and "MQE." The radar chart shows' # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch(self):
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", torch_dtype="float16", device_map=torch_device
|
||||
)
|
||||
|
||||
inputs = self.processor(
|
||||
text=[self.prompt_image, self.prompt_video],
|
||||
images=self.image,
|
||||
videos=self.video,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, torch.float16)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['user\n\nWhat do you see in this image?\nassistant\nThe image is a radar chart that compares the performance of different models in a specific task, likely related', 'user\n\nWhat do you see in this video?\nassistant\nA child wearing a light blue sleeveless top and pink pants is seen sitting on a bed, eng'] # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_video(self):
|
||||
# related to (#29835)
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
torch_dtype="float16",
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
inputs = self.processor(text=self.prompt_video, videos=self.video, return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=40)
|
||||
EXPECTED_DECODED_TEXT = 'user\n\nWhat do you see in this video?\nassistant\nA child wearing a light blue sleeveless top and pink pants is seen sitting on a bed, engrossed in reading a book.' # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_multi_image(self):
|
||||
# related to (#29835)
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
torch_dtype="float16",
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
prompt = (
|
||||
"user\n<image><image>\nWhat is the difference between these images?<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = self.processor(text=prompt, images=[self.image, image], return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=40)
|
||||
EXPECTED_DECODED_TEXT = "user\n\nWhat is the difference between these images?\nassistant\nThe images you've provided appear to be related to a graphical representation of a radar chart, which is a type of data visualization used to show the distribution of a particular variable across a geographic area. The" # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_multi_video(self):
|
||||
# related to (#29835)
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
torch_dtype="float16",
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
prompt = "user\n<video><video>\nAre these videos identical?<|im_end|>\n<|im_start|>assistant\n"
|
||||
inputs = self.processor(text=prompt, videos=[self.video, self.video], return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=40)
|
||||
EXPECTED_DECODED_TEXT = "user\n\nAre these videos identical?\nassistant\nNo, the video is not identical; it shows slight variations in the child's actions and the background." # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch_different_resolutions(self):
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", torch_dtype="float16", device_map=torch_device
|
||||
)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
|
||||
|
||||
inputs = self.processor(
|
||||
text=[self.prompt_image, self.prompt_image],
|
||||
images=[lowres_img, cats_image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, torch.float16)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
EXPECTED_DECODED_TEXT = ['user\n\nWhat do you see in this image?\nassistant\nThe image shows a scene from a wildlife camera, likely a security camera, capturing a moment in a natural setting. It features two deer, one larger and one smaller, grazing on the grass. The environment is foggy, suggesting early morning or late', 'user\n\nWhat do you see in this image?\nassistant\nIn the tranquil setting of this image, two cats are enjoying a peaceful nap on a vibrant pink blanket. The cat on the left, with its gray and black striped fur, is lying on its side, its head comfortably resting on the blanket. Its'] # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_batch_matches_single(self):
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
torch_dtype="float16",
|
||||
device_map=torch_device,
|
||||
)
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
|
||||
cats_image = Image.open(requests.get(url, stream=True).raw)
|
||||
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
|
||||
|
||||
inputs_batched = self.processor(
|
||||
text=[self.prompt_image, self.prompt_image],
|
||||
images=[lowres_img, cats_image],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device, torch.float16)
|
||||
|
||||
inputs_single = self.processor(
|
||||
text=self.prompt_image, images=lowres_img, return_tensors="pt", padding=True
|
||||
).to(torch_device, torch.float16)
|
||||
|
||||
# verify generation
|
||||
output_batched = model.generate(**inputs_batched, max_new_tokens=50)
|
||||
output_single = model.generate(**inputs_single, max_new_tokens=50)
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output_batched[0], skip_special_tokens=True),
|
||||
self.processor.decode(output_single[0], skip_special_tokens=True),
|
||||
)
|
277
tests/models/llava_onevision/test_processing_llava_onevision.py
Normal file
277
tests/models/llava_onevision/test_processing_llava_onevision.py
Normal file
@ -0,0 +1,277 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
LlavaOnevisionImageProcessor,
|
||||
LlavaOnevisionProcessor,
|
||||
LlavaOnevisionVideoProcessor,
|
||||
Qwen2TokenizerFast,
|
||||
)
|
||||
|
||||
|
||||
@require_vision
|
||||
class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = LlavaOnevisionProcessor
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = LlavaOnevisionImageProcessor()
|
||||
video_processor = LlavaOnevisionVideoProcessor()
|
||||
tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
|
||||
processor = LlavaOnevisionProcessor(
|
||||
video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_Video_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_chat_template(self):
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
expected_prompt = "<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_image_processor_defaults_preserved_by_image_kwargs(self):
|
||||
# Rewrite as llava-next image processor return pixel values with an added dimesion for image patches
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor", size=(234, 234))
|
||||
video_processor = self.get_component("video_processor", size=(234, 234))
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input)
|
||||
# added dimension for image patches
|
||||
self.assertEqual(len(inputs["pixel_values"][0][0][0]), 234)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor", crop_size=(234, 234))
|
||||
video_processor = self.get_component("video_processor", size=(234, 234))
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, size=[224, 224])
|
||||
# added dimension for image patches
|
||||
self.assertEqual(len(inputs["pixel_values"][0][0][0]), 224)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
size={"height": 214, "width": 214},
|
||||
padding="max_length",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
# added dimension for image patches
|
||||
self.assertEqual(inputs["pixel_values"].shape[3], 214)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs_batched(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer", "upper older longer string"]
|
||||
image_input = self.prepare_image_inputs() * 2
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
size={"height": 214, "width": 214},
|
||||
padding="longest",
|
||||
max_length=76,
|
||||
)
|
||||
self.assertEqual(inputs["pixel_values"].shape[3], 214)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 5)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_structured_kwargs_nested(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[3], 214)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_structured_kwargs_nested_from_dict(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
self.assertEqual(inputs["pixel_values"].shape[3], 214)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_doubly_passed_kwargs(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer"]
|
||||
image_input = self.prepare_image_inputs()
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
images_kwargs={"size": {"height": 222, "width": 222}},
|
||||
size={"height": 214, "width": 214},
|
||||
)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=112)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 112)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_tokenizer_defaults_preserved_by_kwargs(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 117)
|
@ -4808,6 +4808,8 @@ class ModelTesterMixin:
|
||||
batch_size, sequence_length = inputs["input_ids"].shape
|
||||
vocab_size = config.get_text_config().vocab_size
|
||||
model = model_class(config).to(device=torch_device).eval()
|
||||
# some models have labels but `num_logits_to_keep` should not be used in train mode
|
||||
_ = inputs.pop("labels", None)
|
||||
|
||||
# num_logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
all_logits = model(**inputs, num_logits_to_keep=0).logits
|
||||
|
Loading…
Reference in New Issue
Block a user