transformers/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
Raushan Turganbay 913330ca9f
VLMs: fix number of image tokens (#34332)
* fix

* fix tests

* add tests

* style

* style

* fix qwen after rebase

* fix video llava
2024-10-30 10:21:37 +01:00

547 lines
23 KiB
Python

# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Qwen2-VL model."""
import gc
import unittest
import requests
from transformers import (
AutoProcessor,
Qwen2VLConfig,
Qwen2VLForConditionalGeneration,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
)
if is_torch_available():
import torch
else:
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
class Qwen2VLVisionText2TextModelTester:
    """Build a tiny Qwen2-VL config plus matching dummy inputs for the common test mixins.

    With the default settings every image is a single ``patch_size x patch_size`` patch
    (``image_size == patch_size == 14``, ``spatial_merge_size == 1``, grid ``[1, 1, 1]``),
    and each sample in ``input_ids`` carries exactly one image token.
    """

    def __init__(
        self,
        parent,
        batch_size=3,
        seq_length=7,
        num_channels=3,
        ignore_index=-100,
        image_size=14,
        bos_token_id=0,
        eos_token_id=1,
        pad_token_id=2,
        vision_start_token_id=151652,
        image_token_id=151655,
        video_token_id=151656,
        hidden_act="silu",
        hidden_size=32,
        vocab_size=152064,
        intermediate_size=37,
        max_position_embeddings=512,
        max_window_layers=3,
        model_type="qwen2_vl",
        num_attention_heads=4,
        num_hidden_layers=4,
        num_key_value_heads=2,
        rope_theta=10000,
        tie_word_embeddings=True,
        is_training=True,
        vision_config=None,
        rope_scaling=None,
    ):
        """
        Args:
            parent: the running `unittest.TestCase`; its assertion methods are used by the checks.
            vision_config: dict of vision-tower kwargs; `None` selects a fresh tiny default.
            rope_scaling: rope-scaling dict; `None` selects a fresh mrope default.
        """
        # Build the dict defaults per instance instead of using shared mutable default
        # arguments, so an in-place change made through one tester (or by whatever the
        # dict is handed to, e.g. a config constructor) cannot leak into later testers.
        if vision_config is None:
            vision_config = {
                "depth": 2,
                "embed_dim": 32,
                "hidden_act": "quick_gelu",
                "hidden_size": 32,
                "mlp_ratio": 4,
                "num_heads": 4,
                "patch_size": 14,
                "spatial_merge_size": 1,
                "temporal_patch_size": 2,
            }
        if rope_scaling is None:
            rope_scaling = {"type": "mrope", "mrope_section": [2, 1, 1]}

        self.parent = parent
        self.ignore_index = ignore_index
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.vision_start_token_id = vision_start_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.max_window_layers = max_window_layers
        self.model_type = model_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.rope_theta = rope_theta
        self.tie_word_embeddings = tie_word_embeddings
        self.vision_config = vision_config
        self.rope_scaling = rope_scaling
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.is_training = is_training
        self.vocab_size = vocab_size
        self.num_image_tokens = 32
        # `num_image_tokens` extra positions are appended to the text length; the single
        # image token of each sample is placed at index `num_image_tokens`.
        self.seq_length = seq_length + self.num_image_tokens

    def get_config(self):
        """Return a `Qwen2VLConfig` built from the tester's hyper-parameters."""
        return Qwen2VLConfig(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            vision_config=self.vision_config,
            model_type=self.model_type,
            max_window_layers=self.max_window_layers,
            rope_scaling=self.rope_scaling,
            tie_word_embeddings=self.tie_word_embeddings,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            vision_start_token_id=self.vision_start_token_id,
            image_token_id=self.image_token_id,
            video_token_id=self.video_token_id,
            vocab_size=self.vocab_size,
        )

    def prepare_config_and_inputs(self):
        """Return ``(config, pixel_values)``.

        ``pixel_values`` is a 2-D tensor of flattened patches with
        ``batch_size * image_size**2 // patch_size**2`` rows, each of length
        ``num_channels * patch_size**2 * temporal_patch_size``.
        """
        config = self.get_config()
        patch_size = config.vision_config.patch_size
        temporal_patch_size = config.vision_config.temporal_patch_size
        pixel_values = floats_tensor(
            [
                self.batch_size * (self.image_size**2) // (patch_size**2),
                self.num_channels * (patch_size**2) * temporal_patch_size,
            ]
        )
        return config, pixel_values

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the format expected by the common test mixins."""
        config, pixel_values = self.prepare_config_and_inputs()
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        # `vocab_size` covers all vision special ids, so random sampling can produce them.
        # Scrub every one of them first (a stray video / vision-start token could be read
        # as an extra vision input), then place exactly one image token per row.
        for special_token_id in (self.image_token_id, self.video_token_id, self.vision_start_token_id):
            input_ids[input_ids == special_token_id] = self.pad_token_id
        input_ids[:, self.num_image_tokens] = self.image_token_id

        labels = torch.zeros(
            (self.batch_size, self.seq_length),
            dtype=torch.long,
            device=torch_device,
        )
        inputs_dict = {
            "pixel_values": pixel_values,
            "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size),
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        return config, inputs_dict

    def create_and_check_qwen2_vl_model_fp16_forward(
        self, config, input_ids, pixel_values, attention_mask, image_grid_thw
    ):
        """Run a forward pass with the model cast to fp16 and assert the logits contain no NaNs."""
        model = Qwen2VLForConditionalGeneration(config=config)
        model.to(torch_device)
        model.half()
        model.eval()
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_grid_thw=image_grid_thw,
            # Feed pixels in the same half-precision dtype as the fp16 weights.
            pixel_values=pixel_values.to(torch.float16),
            return_dict=True,
        )["logits"]
        self.parent.assertFalse(torch.isnan(logits).any().item())

    def create_and_check_qwen2_vl_model_fp16_autocast_forward(
        self, config, input_ids, pixel_values, attention_mask, image_grid_thw
    ):
        """Run a forward pass under CUDA fp16 autocast and assert the logits contain no NaNs."""
        config.torch_dtype = torch.float16
        model = Qwen2VLForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                image_grid_thw=image_grid_thw,
                # Match the fp16 autocast dtype rather than mixing in bfloat16.
                pixel_values=pixel_values.to(torch.float16),
                return_dict=True,
            )["logits"]
        self.parent.assertFalse(torch.isnan(logits).any().item())
@require_torch
class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """
    Common modeling and generation tests for `Qwen2VLForConditionalGeneration`.
    """

    all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
    all_generative_model_classes = all_model_classes
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = Qwen2VLVisionText2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Qwen2VLConfig, has_text_modality=False)

    def test_initialization(self):
        # With a zero-initialised config, every trainable parameter should average to 0 or 1.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        zero_init_config = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=zero_init_config)
            trainable = ((pname, p) for pname, p in model.named_parameters() if p.requires_grad)
            for param_name, param in trainable:
                rounded_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
                self.assertIn(
                    rounded_mean,
                    [0.0, 1.0],
                    msg=f"Parameter {param_name} of model {model_class} seems not properly initialized",
                )

    def test_mismatching_num_image_tokens(self):
        """
        Check that the model raises a `ValueError` with an explicit message when the number
        of images does not match the number of image tokens in the text.

        The multi-image case, where one prompt holds several image tokens, is covered too.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)

            # Untouched inputs must run cleanly.
            _ = model(**input_dict)

            # Keep only the last image while every sample still holds its image token.
            patch_size = config.vision_config.patch_size
            patches_per_image = (self.model_tester.image_size**2) // (patch_size**2)
            input_dict["pixel_values"] = input_dict["pixel_values"][-patches_per_image:]
            input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # Multi-image case: stack two copies of a prompt that has exactly one image token.
            single_ids = input_dict["input_ids"][:1]
            single_pixels = input_dict["pixel_values"][:patches_per_image]
            single_grid = input_dict["image_grid_thw"][:1]
            paired_ids = torch.cat([single_ids, single_ids], dim=0)

            # Two image tokens but only one image -> error.
            with self.assertRaises(ValueError):
                _ = model(input_ids=paired_ids, pixel_values=single_pixels, image_grid_thw=single_grid)

            # Two image tokens and two images -> fine.
            paired_pixels = torch.cat([single_pixels, single_pixels], dim=0)
            paired_grid = torch.cat([single_grid, single_grid], dim=0)
            _ = model(input_ids=paired_ids, pixel_values=paired_pixels, image_grid_thw=paired_grid)

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Feedforward chunking is not yet supported")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip(reason="Generate needs input ids")
    def test_inputs_embeds_matches_input_ids_with_generate(self):
        pass

    @unittest.skip(reason="CPU offload is not yet supported")
    def test_cpu_offload(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
    def test_model_parallelism(self):
        pass

    @unittest.skip(reason="Compile not yet supported because in Qwen2VL models")
    def test_sdpa_can_compile_dynamic(self):
        pass

    @unittest.skip(reason="Compile not yet supported because in Qwen2VL models")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.")
    def test_multi_gpu_data_parallel_forward(self):
        pass

    @unittest.skip(reason="We cannot configure to output a smaller model.")
    def test_model_is_small(self):
        pass

    @unittest.skip(
        reason="Qwen2-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
    )
    def test_beam_search_low_memory(self):
        pass

    @unittest.skip(
        reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
    )
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):
    """Slow end-to-end checks of `Qwen/Qwen2-VL-7B-Instruct`: preprocessing and greedy generation."""

    def setUp(self):
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        self.messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What kind of dog is this?"},
                ],
            }
        ]
        url = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg"
        # Bound the network wait so a stalled connection cannot hang CI, and fail with an
        # HTTP error instead of an opaque image-decoding error if the download fails.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        self.image = Image.open(response.raw)

    def tearDown(self):
        # Release the large model between tests.
        gc.collect()
        torch.cuda.empty_cache()

    @slow
    def test_small_model_integration_test(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
        )

        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text], images=[self.image], return_tensors="pt")

        # verify preprocessing (unittest assertions keep working under `python -O`)
        expected_input_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151655]  # fmt: skip
        self.assertEqual(expected_input_ids, inputs.input_ids[0].tolist()[:17])

        expected_pixel_slice = torch.tensor(
            [
                [0.8792, 0.8792, 0.9084],
                [1.1858, 1.1858, 1.2296],
                [1.2004, 1.2004, 1.2150],
                [1.4340, 1.4340, 1.4194],
                [1.3902, 1.4048, 1.4194],
                [1.5216, 1.5362, 1.5362],
            ],
            dtype=torch.float32,
            device="cpu",
        )
        self.assertTrue(torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3))

        # verify generation
        inputs = inputs.to(torch_device)

        output = model.generate(**inputs, max_new_tokens=30)
        EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets"

        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    def test_small_model_integration_test_batch(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
        )
        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to(
            torch_device
        )

        # batch of two identical prompt/image pairs
        output = model.generate(**inputs, max_new_tokens=30)

        EXPECTED_DECODED_TEXT = [
            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets'
        ]  # fmt: skip
        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    def test_small_model_integration_test_batch_wo_image(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
        )
        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        messages2 = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who are you?"},
        ]
        text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to(
            torch_device
        )

        # batch mixing an image prompt with a text-only prompt
        output = model.generate(**inputs, max_new_tokens=30)

        EXPECTED_DECODED_TEXT = [
            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets',
            'system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to assist with various tasks and answer questions to the best of my'
        ]  # fmt: skip
        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    def test_small_model_integration_test_batch_different_resolutions(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
        )
        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        text2 = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        image2 = self.image.resize((224, 224))
        inputs = self.processor(text=[text, text2], images=[self.image, image2], padding=True, return_tensors="pt").to(
            torch_device
        )

        # it should not matter whether two images are the same size or not
        output = model.generate(**inputs, max_new_tokens=30)

        EXPECTED_DECODED_TEXT = [
            "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
            "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
        ]
        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_flash_attn
    @require_torch_gpu
    def test_small_model_integration_test_batch_flashatt2(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to(
            torch_device
        )

        # batch of two identical prompt/image pairs; with flash attention both rows must decode identically
        output = model.generate(**inputs, max_new_tokens=30)

        EXPECTED_DECODED_TEXT = [
            "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
            "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
        ]
        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )
        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True)[0],
            self.processor.batch_decode(output, skip_special_tokens=True)[1],
        )

    @slow
    @require_flash_attn
    @require_torch_gpu
    def test_small_model_integration_test_batch_wo_image_flashatt2(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
        messages2 = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who are you?"},
        ]
        text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to(
            torch_device
        )

        # batch mixing an image prompt with a text-only prompt
        output = model.generate(**inputs, max_new_tokens=30)

        EXPECTED_DECODED_TEXT = [
            "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
            "system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to answer a wide range of questions and provide information on various topics",
        ]

        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )