Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)
VLMs: fix number of image tokens (#34332)
* fix
* fix tests
* add tests
* style
* style
* fix qwen after rebase
* fix video llava
commit 913330ca9f (parent 0f764a5af7)
@@ -1288,7 +1288,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0]
+            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
             if n_image_tokens_in_text != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
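The Chameleon fix changes what counts as the expected number of image placeholders: get_image_tokens returns one row of discrete tokens per image, so the text must contain num_images * tokens_per_image placeholders, not just num_images. A minimal sketch of the two counts, using made-up shapes and a hypothetical placeholder id:

import torch

# Toy stand-in for get_image_tokens(pixel_values): 2 images, 2 discrete tokens each.
# Real Chameleon uses far more tokens per image; these shapes are illustrative only.
image_tokens = torch.randint(0, 100, (2, 2))

image_token_id = 7  # hypothetical placeholder id
input_ids = torch.tensor([[1, 7, 7, 2],   # prompt 1: two image placeholders
                          [3, 7, 7, 4]])  # prompt 2: two image placeholders

n_image_tokens_in_text = (input_ids == image_token_id).sum().item()  # 4
n_old = image_tokens.shape[0]                                        # 2, spurious mismatch
n_new = image_tokens.shape[0] * image_tokens.shape[1]                # 4, matches the text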
@@ -527,8 +527,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):

         # TODO: @raushan retain only the new behavior after v4.47
         elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
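The old LLaVA check had two issues at once: sum(dim=-1)[0] counted placeholders in the first prompt only, and image_features.shape[1] counted the tokens of a single image, so dropping an image from pixel_values could go unnoticed until the embedding scatter failed. A hedged sketch of the before/after counts, with invented shapes:

import torch

image_token_index = 32000  # hypothetical placeholder id
# Two prompts; assume the processor already expanded each image into 3 placeholder tokens.
input_ids = torch.tensor([[1, 32000, 32000, 32000, 5],
                          [2, 32000, 32000, 32000, 6]])
# Projected vision features: (num_images, tokens_per_image, hidden_dim).
image_features = torch.zeros(2, 3, 16)

old_tokens = (input_ids == image_token_index).sum(dim=-1)[0].item()  # 3, first prompt only
old_feats = image_features.shape[1]                                  # 3, ignores the image count
# The old counts also stay equal if one image is dropped from pixel_values, hiding the mismatch.

new_tokens = (input_ids == image_token_index).sum().item()           # 6, whole batch
new_feats = image_features.shape[0] * image_features.shape[1]        # 6, all images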
@@ -1020,6 +1020,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -533,6 +533,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -679,6 +679,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
             )
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
+
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -704,6 +705,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
             )
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)
+
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
             if n_video_tokens != n_video_features:
@@ -1503,13 +1503,14 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         mrope_position_deltas = []
         if image_grid_thw is not None or video_grid_thw is not None:
             total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
             position_ids = torch.ones(
                 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
             )
             image_index, video_index = 0, 0
             for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
+                input_ids = input_ids[attention_mask[i] == 1]
                 image_nums, video_nums = 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
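A note on the Qwen2-VL hunk: the loop for i, input_ids in enumerate(total_input_ids) rebinds input_ids to a single row, and the old code only masked that row when an attention mask was passed. Defaulting the mask to all ones up front lets every row be masked unconditionally, so padded and unpadded calls take the same path. A rough sketch of the pattern, with invented tensors:

import torch

total_input_ids = torch.tensor([[1, 2, 3, 0],   # 0 plays the role of padding here
                                [4, 5, 6, 7]])
attention_mask = None

# Same default as the patched code: treat every position as valid when no mask is given.
if attention_mask is None:
    attention_mask = torch.ones_like(total_input_ids)

for i, input_ids in enumerate(total_input_ids):
    # input_ids is now one row; keep only the positions the mask marks as valid.
    input_ids = input_ids[attention_mask[i] == 1]
    # ...the per-sequence m-rope index computation continues from here...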
@@ -628,8 +628,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
         # TODO: @raushan retain only the new behavior after v4.47
         else:
             if pixel_values_images is not None:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0] * image_features.shape[1]
                 if n_image_tokens != n_image_features:
                     raise ValueError(
                         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -644,8 +644,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

             if pixel_values_videos is not None:
-                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
-                n_video_features = video_features.shape[1]
+                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                n_video_features = video_features.shape[0] * video_features.shape[1]
                 if n_video_tokens != n_video_features:
                     raise ValueError(
                         f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
@@ -517,8 +517,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):

         # TODO: @raushan retain only the new behavior after v4.47
         elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
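The VideoLlava and VipLLaVA hunks above are the same two-line pattern as the LLaVA fix: token counts are summed over the whole batch instead of the first prompt, and feature counts multiply the number of images (or videos) by the tokens each one contributes, so the sketch under the LLaVA hunk applies to them unchanged.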
@@ -235,6 +235,35 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
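The new test_mismatching_num_image_tokens test first confirms that an unmodified forward pass succeeds, then drops one image while keeping its placeholder in the text (expecting a ValueError), and finally builds a two-prompt batch to check that two images with two image tokens pass while one image with two image tokens fails. The same test is replicated below for LLaVA-NeXT, LLaVA-NeXT-Video, PaliGemma, Qwen2-VL, VideoLlava, and VipLLaVA, differing only in the model-specific inputs it has to slice and duplicate (image_sizes, image_grid_thw, pixel_values_images).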
@@ -283,6 +283,38 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
@@ -303,6 +303,38 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
@@ -236,6 +236,36 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
@@ -58,7 +58,7 @@ class Qwen2VLVisionText2TextModelTester:
    def __init__(
        self,
        parent,
-        batch_size=2,
+        batch_size=3,
        seq_length=7,
        num_channels=3,
        ignore_index=-100,
@@ -245,6 +245,40 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            patch_size = config.vision_config.patch_size
+            one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
+            input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
+            input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:one_img_length]
+            image_grid_thw = input_dict["image_grid_thw"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
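The Qwen2-VL variant of the test needs extra care when "removing one image": the model flattens the patches of every image into the leading dimension of pixel_values, so one image corresponds to (image_size ** 2) // (patch_size ** 2) rows rather than one batch entry, while image_grid_thw keeps one (t, h, w) row per image. A small sketch of that layout, with assumed sizes:

import torch

# Assumed illustrative sizes; the tester's real values may differ.
image_size, patch_size, patch_dim = 56, 14, 32
one_img_length = (image_size ** 2) // (patch_size ** 2)  # patches contributed by one image

# Two images occupy 2 * one_img_length rows of the flattened pixel_values.
pixel_values = torch.zeros(2 * one_img_length, patch_dim)
image_grid_thw = torch.tensor([[1, 4, 4], [1, 4, 4]])    # one (t, h, w) row per image

# Keeping a single image therefore means slicing rows, not batch entries.
pixel_values_one = pixel_values[-one_img_length:, ...]
image_grid_thw_one = image_grid_thw[-1:, ...]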
@@ -123,9 +123,9 @@ class VideoLlavaVisionText2TextModelTester:
        self.batch_size = 5
        self.num_channels = 3
        self.image_size = 224
-        self.encoder_seq_length = 64
+        self.encoder_seq_length = 246
        self.num_image_tokens = 25
-        self.num_video_tokens = 26
+        self.num_video_tokens = 26 * self.num_frames
        self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens

    def get_config(self):
@@ -267,7 +267,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
        # if we remove some images from inputs leaving only one
        # image number mismatch error should raise
        inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
            _ = model(**inputs)

    def test_video_only_input(self):
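The expected exception in this pre-existing VideoLlava test changes because the model now validates the counts itself: a mismatched number of images fails fast with a ValueError and a readable message, instead of surfacing later as a RuntimeError from the masked_scatter that merges the image features into the embeddings.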
@@ -401,6 +401,35 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values_images"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+

@require_torch
class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@@ -217,6 +217,36 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
            out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )