VLMs: fix number of image tokens (#34332)

* fix

* fix tests

* add tests

* style

* style

* fix qwen after rebase

* fix video llava
Raushan Turganbay 2024-10-30 10:21:37 +01:00 committed by GitHub
parent 0f764a5af7
commit 913330ca9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 237 additions and 15 deletions
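
The recurring change across the model files below is how both sides of the image-token sanity check are counted. Some models previously counted placeholder tokens in only the first batch sample (.sum(dim=-1)[0].item()) and counted features for only a single image (image_features.shape[1]); both counts are now taken over the whole batch. A minimal sketch of the corrected check, under assumed shapes — illustrative, not the exact transformers code:

import torch

def check_image_tokens(input_ids, image_features, image_token_id):
    # Assumed shapes: input_ids is (batch, seq_len);
    # image_features is (num_images, tokens_per_image, hidden_dim).
    # Count placeholder tokens across the whole batch, not just sample 0.
    n_image_tokens = (input_ids == image_token_id).sum().item()
    # Count features across all images, not just one image's token count.
    n_image_features = image_features.shape[0] * image_features.shape[1]
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        )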

View File

@@ -1288,7 +1288,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0]
+            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
             if n_image_tokens_in_text != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"

View File

@@ -527,8 +527,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
         elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -1020,6 +1020,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -533,6 +533,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -679,6 +679,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
             )
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -704,6 +705,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
             )
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
             if n_video_tokens != n_video_features:
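
LLaVA-NeXT-Video and LLaVA-OneVision compare against shape[0] alone rather than a product because their features are flattened over the image/video and token dimensions first, as in the video_features.flatten(0, 1) call above. A quick sketch of why the two counts agree, with made-up sizes:

import torch

video_features = torch.randn(2, 26, 32)  # (num_videos, tokens_per_video, hidden_dim), sizes made up
flat = video_features.flatten(0, 1)      # (num_videos * tokens_per_video, hidden_dim)
assert flat.shape[0] == video_features.shape[0] * video_features.shape[1]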

View File

@@ -1503,13 +1503,14 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         mrope_position_deltas = []
         if image_grid_thw is not None or video_grid_thw is not None:
             total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
             position_ids = torch.ones(
                 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
             )
             image_index, video_index = 0, 0
             for i, input_ids in enumerate(total_input_ids):
-                if attention_mask is not None:
-                    input_ids = input_ids[attention_mask[i] == 1]
+                input_ids = input_ids[attention_mask[i] == 1]
                 image_nums, video_nums = 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
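
The Qwen2-VL change hoists the None check out of the loop: instead of testing for a mask on every batch row, attention_mask defaults once to all-ones so each row can be filtered unconditionally. A minimal sketch of the pattern, with assumed shapes and made-up values:

import torch

input_ids = torch.randint(0, 100, (2, 7))  # (batch, seq_len), values made up
attention_mask = None
if attention_mask is None:
    attention_mask = torch.ones_like(input_ids)  # treat every position as attended
for i, row in enumerate(input_ids):
    row = row[attention_mask[i] == 1]  # keep only non-padded positions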

View File

@@ -628,8 +628,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
             # TODO: @raushan retain only the new behavior after v4.47
         else:
             if pixel_values_images is not None:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0] * image_features.shape[1]
                 if n_image_tokens != n_image_features:
                     raise ValueError(
                         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -644,8 +644,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
             if pixel_values_videos is not None:
-                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
-                n_video_features = video_features.shape[1]
+                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                n_video_features = video_features.shape[0] * video_features.shape[1]
                 if n_video_tokens != n_video_features:
                     raise ValueError(
                         f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

View File

@@ -517,8 +517,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
             # TODO: @raushan retain only the new behavior after v4.47
         elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -235,6 +235,35 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )

View File

@@ -283,6 +283,38 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )

View File

@@ -303,6 +303,38 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            input_dict["image_sizes"] = input_dict["image_sizes"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            image_sizes = input_dict["image_sizes"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )

View File

@@ -236,6 +236,36 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )

View File

@@ -58,7 +58,7 @@ class Qwen2VLVisionText2TextModelTester:
     def __init__(
         self,
         parent,
-        batch_size=2,
+        batch_size=3,
         seq_length=7,
         num_channels=3,
         ignore_index=-100,
@@ -245,6 +245,40 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            patch_size = config.vision_config.patch_size
+            one_img_length = (self.model_tester.image_size**2) // (patch_size**2)
+            input_dict["pixel_values"] = input_dict["pixel_values"][-one_img_length:, ...]
+            input_dict["image_grid_thw"] = input_dict["image_grid_thw"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:one_img_length]
+            image_grid_thw = input_dict["image_grid_thw"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
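
Unlike the other testers, the Qwen2-VL test above slices pixel_values by patch count rather than by batch index: the tester's pixel values pack the patches of all images along one leading dimension, so "one image" spans one_img_length rows. A sketch of the arithmetic, with hypothetical sizes:

image_size, patch_size = 28, 14                      # hypothetical test sizes
one_img_length = (image_size**2) // (patch_size**2)  # patches per image: 4
# pixel_values is packed as (num_images * one_img_length, patch_dim),
# so dropping all but one image means keeping the last one_img_length rows.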

View File

@@ -123,9 +123,9 @@ class VideoLlavaVisionText2TextModelTester:
         self.batch_size = 5
         self.num_channels = 3
         self.image_size = 224
-        self.encoder_seq_length = 64
+        self.encoder_seq_length = 246
         self.num_image_tokens = 25
-        self.num_video_tokens = 26
+        self.num_video_tokens = 26 * self.num_frames
         self.seq_length = seq_length + self.num_image_tokens + self.num_video_tokens

     def get_config(self):
@@ -267,7 +267,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
             # if we remove some images from inputs leaving only one
             # image number mismatch error should raise
             inputs["pixel_values_images"] = inputs["pixel_values_images"][:1]
-            with self.assertRaises(RuntimeError):
+            with self.assertRaises(ValueError):
                 _ = model(**inputs)

     def test_video_only_input(self):
@@ -401,6 +401,35 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values_images"] = input_dict["pixel_values_images"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values_images"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values_images=pixel_values)
+
 @require_torch
 class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@@ -217,6 +217,36 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))

+    # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        We also need to test multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            _ = model(**input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = input_dict["input_ids"][:1]
+            pixel_values = input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
     @unittest.skip(
         reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )